paint-brush
使用 Wav2vec2 第 1 部分:微调 XLS-R 以实现自动语音识别经过@pictureinthenoise
1,799 讀數
1,799 讀數

使用 Wav2vec2 第 1 部分:微调 XLS-R 以实现自动语音识别

经过 Picture in the Noise29m2024/05/04
Read on Terminal Reader

太長; 讀書

本指南介绍了对 Meta AI 的 wav2vec2 XLS-R 模型进行微调以实现自动语音识别 (ASR) 的步骤。本指南包含有关如何构建可用于微调模型的 Kaggle Notebook 的分步说明。该模型是在智利西班牙语数据集上进行训练的。
featured image - 使用 Wav2vec2 第 1 部分:微调 XLS-R 以实现自动语音识别
Picture in the Noise HackerNoon profile picture
0-item
1-item

介绍

Meta AI 于 2021 年底推出了wav2vec2 XLS-R(“XLS-R”)。 XLS-R 是一种用于跨语言语音表示学习的机器学习(“ML”)模型;它基于 128 种语言的 40 多万小时公开语音音频进行训练。该模型发布后,超越了 Meta AI 的XLSR-53跨语言模型,后者基于 53 种语言的约 5 万小时语音音频进行训练。


本指南介绍了使用Kaggle Notebook对 XLS-R 进行自动语音识别 (ASR) 微调的步骤。该模型将针对智利西班牙语进行微调,但您可以按照一般步骤对您想要的不同语言的 XLS-R 进行微调。


将在配套教程中介绍如何在经过微调的模型上运行推理,这使得本指南成为两部分中的第一部分。由于本微调指南有点长,我决定创建单独的推理专用指南。


假设您具有 ML 背景并了解基本的 ASR 概念。初学者可能难以理解构建步骤。

XLS-R 的一些背景知识

2020 年推出的原始 wav2vec2 模型是在 960 小时的Librispeech数据集语音音频和约 53,200 小时的LibriVox数据集语音音频上进行预训练的。发布时,有两种模型大小可供选择:具有 9500 万个参数的BASE模型和具有 3.17 亿个参数的LARGE模型。


另一方面,XLS-R 已对来自 5 个数据集的多语言语音音频进行了预训练:


  • VoxPopuli :欧洲议会用 23 种欧洲语言发表的议会演讲总计约 372,000 小时的音频。
  • 多语言 Librispeech :总计约 50,000 小时的八种欧洲语言的语音音频,其中大部分(约 44,000 小时)的音频数据是英语。
  • CommonVoice :总计约 60 种语言的 7,000 小时语音音频。
  • VoxLingua107 :基于 YouTube 内容的总计约 6,600 小时的涵盖 107 种语言的语音音频。
  • BABEL :基于电话对话语音,总计约 1,100 小时的 17 种非洲和亚洲语言的语音音频。


XLS-R 共有 3 个模型: XLS-R (0.3B)具有 3 亿个参数、 XLS-R (1B)具有 10 亿个参数, XLS-R (2B)具有 20 亿个参数。本指南将使用 XLS-R (0.3B) 模型。

方法

关于如何微调wav2vev2模型,有很多很棒的文章,也许这篇文章可以说是某种“黄金标准”。当然,这里的一般方法模仿了您在其他指南中可以找到的方法。您将:


  • 加载音频数据和相关文本转录的训练数据集。
  • 根据数据集中的文本转录创建词汇表。
  • 初始化 wav2vec2 处理器,它将从输入数据中提取特征,并将文本转录转换为标签序列。
  • 对处理后的输入数据进行 wav2vec2 XLS-R 微调。


但是,本指南与其他指南有三个主要区别:


  1. 该指南没有提供太多有关 ML 和 ASR 概念的“内联”讨论。
    • 虽然单个笔记本单元的每个子部分将包含有关特定单元的用途/目的的详细信息,但假设您具有现有的 ML 背景并且您了解基本的 ASR 概念。
  2. 您将构建的 Kaggle Notebook 在顶层单元中组织实用方法。
    • 尽管许多微调笔记本往往具有某种“意识流”式的布局,但我选择将所有实用方法组织在一起。如果您不熟悉 wav2vec2,您可能会发现这种方法令人困惑。但是,重申一下,我在每个单元格的专用子部分中解释每个单元格的用途时会尽力明确。如果您刚刚了解 wav2vec2,快速浏览一下我的 HackerNoon 文章wav2vec2 for Automatic Speech Recognition in Plain English可能会对您有所帮助。
  3. 本指南仅描述微调的步骤。
    • 简介中所述,我选择创建一份单独的配套指南,介绍如何在您将生成的微调 XLS-R 模型上运行推理。这样做是为了防止本指南过长。

先决条件和开始之前

要完成本指南,您需要具备:


  • 现有的Kaggle 帐户。如果您没有现有的 Kaggle 帐户,则需要创建一个。
  • 现有的Weights and Biases 帐户(“WandB”) 。如果您没有现有的 Weights and Biases 帐户,则需要创建一个。
  • WandB API 密钥。如果您没有 WandB API 密钥,请按照此处的步骤操作。
  • 具有中级 Python 知识。
  • 具有使用 Kaggle Notebooks 的中级知识。
  • 对 ML 概念有中级了解。
  • ASR 概念的基本知识。


在开始构建笔记本之前,查看下面的两个小节可能会有所帮助。它们描述了:


  1. 训练数据集。
  2. 训练期间使用的词错误率(“WER”)指标。

训练数据集

引言中所述,XLS-R 模型将针对智利西班牙语进行微调。具体数据集是 Guevara-Rukoz 等人开发的智利西班牙语语音数据集。可在OpenSLR下载。该数据集由两个子数据集组成:(1) 2,636 个智利男性说话者的录音和 (2) 1,738 个智利女性说话者的录音。


每个子数据集包含一个line_index.tsv索引文件。每个索引文件的每一行包含一对音频文件名和相关文件中音频的转录,例如:


 clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente


为了方便起见,我已将智利西班牙语语音数据集上传到 Kaggle。Kaggle 数据集中有一个用于智利男性说话者的录音,还有一个用于智利女性说话者的录音。这些 Kaggle 数据集将添加到您将按照本指南中的步骤构建的 Kaggle Notebook 中。

单词错误率 (WER)

WER 是衡量自动语音识别模型性能的指标之一。WER 提供了一种机制来衡量文本预测与文本参考的接近程度。WER 通过记录 3 种类型的错误来实现这一点:


  • 替换 ( S ):当预测中包含的单词与参考中的类似单词不同时,就会记录替换错误。例如,当预测拼错了参考中的单词时,就会发生这种情况。

  • 删除( D ):当预测包含参考中不存在的单词时,就会记录为删除错误。

  • 插入( I ):当预测不包含参考中存在的单词时,就会记录插入错误。


显然,WER 在单词级别起作用。WER 指标的公式如下:


 WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference


以下是西班牙语的一个简单 WER 示例:


 prediction: "Él está saliendo." reference: "Él está saltando."


表格有助于直观地看到预测中的错误:

文本

单词 1

单词 2

单词 3

预言

埃勒

萨利恩多

参考

埃勒

跳跃


正确的

正确的

代换

预测包含 1 个替换错误、0 个删除错误和 0 个插入错误。因此,此示例的 WER 为:


 WER = 1 + 0 + 0 / 3 = 1/3 = 0.33


显然,单词错误率并不一定能告诉我们存在哪些具体错误。在上面的例子中,WER 识别出单词 3在预测文本中包含错误,但它并没有告诉我们字符ie在预测中是错误的。其他指标,例如字符错误率(“CER”),可用于更精确的错误分析。

构建微调笔记本

您现在可以开始构建微调笔记本了。


  • 步骤 1步骤 2将指导您设置 Kaggle Notebook 环境。
  • 步骤 3指导您构建笔记本本身。它包含 32 个子步骤,代表微调笔记本的 32 个单元。
  • 步骤 4将指导您运行笔记本、监控训练并保存模型。

步骤 1-获取您的 WandB API 密钥

必须配置您的 Kaggle Notebook 才能使用您的 WandB API 密钥将训练运行数据发送到 WandB。为此,您需要复制它。


  1. 登录 WandB: www.wandb.com
  2. 导航至www.wandb.ai/authorize
  3. 复制您的 API 密钥以供下一步使用。

第 2 步 - 设置你的 Kaggle 环境

步骤 2.1 - 创建新的 Kaggle Notebook


  1. 登录 Kaggle。
  2. 创建一个新的 Kaggle Notebook。
  3. 当然,笔记本的名称可以根据需要更改。本指南使用笔记本名称xls-r-300m-chilean-spanish-asr

步骤 2.2 - 设置您的 WandB API 密钥

Kaggle Secret将用于安全存储您的 WandB API 密钥。


  1. 单击 Kaggle Notebook 主菜单上的“附加组件”
  2. 从弹出菜单中选择“秘密”
  3. 标签字段中输入标签WANDB_API_KEY ,并输入您的 WandB API 密钥作为值。
  4. 确保选中WANDB_API_KEY标签字段左侧的附加复选框。
  5. 单击完成

步骤 2.3 - 添加训练数据集

智利西班牙语语音数据集已作为 2 个不同的数据集上传至 Kaggle:


将这两个数据集都添加到您的 Kaggle Notebook。

步骤 3 - 构建微调笔记本

以下 32 个子步骤按顺序构建微调笔记本的 32 个单元。

步骤 3.1 - CELL 1:安装包

微调笔记本的第一个单元格安装依赖项。将第一个单元格设置为:


 ### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer


  • 第一行将torchaudio包升级到最新版本。torchaudio torchaudio用于加载音频文件并重新采样音频数据。
  • 第二行安装jiwer包,该包是稍后使用 HuggingFace Datasetsload_metric方法所必需的。

步骤 3.2 - CELL 2:导入 Python 包

第二个单元格导入所需的 Python 包。将第二个单元格设置为:


 ### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer


  • 您可能已经熟悉了其中的大部分包。在后续单元构建过程中,我们将解释它们在笔记本中的用法。
  • 值得一提的是,HuggingFace transformers库和相关的Wav2Vec2*类提供了用于微调的功能的支柱。

步骤 3.3 - 单元 3:加载 WER 指标

第三个单元格导入 HuggingFace WER 评估指标。将第三个单元格设置为:


 ### CELL 3: Load WER metric ### wer_metric = load_metric("wer")


  • 如前所述,WER 将用于衡量模型在评估/保留数据上的性能。

步骤 3.4 - 单元 4:登录 WandB

第四个单元格检索您在步骤 2.2中设置的WANDB_API_KEY密钥。将第四个单元格设置为:


 ### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)


  • API 密钥用于配置 Kaggle Notebook,以便将训练运行数据发送到 WandB。

步骤 3.5 - 单元格 5:设置常数

第五个单元格设置将在整个笔记本中使用的常量。将第五个单元格设置为:


 ### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800


  • 笔记本未在此单元格中列出所有可以想到的常数。一些可以用常数表示的值已保留在行内。
  • 上面许多常量的用途应该是不言而喻的。对于那些不言而喻的,它们的用途将在以下子步骤中解释。

步骤 3.6 - CELL 6:读取索引文件、清理文本和创建词汇表的实用方法

第六个单元格定义了用于读取数据集索引文件(请参阅上面的“训练数据集”子部分)以及清理转录文本和创建词汇表的实用方法。将第六个单元格设置为:


 ### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list


  • read_index_file_data方法读取line_index.tsv数据集索引文件并生成包含音频文件名和转录数据的列表列表,例如:


 [ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]


  • truncate_training_dataset方法使用步骤 3.5中设置的NUM_LOAD_FROM_EACH_SET常量截断列表索引文件数据。具体来说, NUM_LOAD_FROM_EACH_SET常量用于指定应从每个数据集加载的音频样本数量。就本指南而言,该数字设置为1600 ,这意味着最终将加载总共3200音频样本。要加载所有样本,请将NUM_LOAD_FROM_EACH_SET设置为字符串值all
  • clean_text方法用于删除每个文本转录中步骤 3.5中分配给SPECIAL_CHARS的正则表达式所指定的字符。这些字符(包括标点符号)可以被删除,因为它们在训练模型以学习音频特征和文本转录之间的映射时不提供任何语义价值。
  • create_vocab方法根据干净的文本转录创建词汇表。简单来说,它从一组已清理的文本转录中提取所有唯一字符。您将在步骤 3.14中看到生成的词汇表的示例。

步骤 3.7 - 单元 7:用于加载和重采样音频数据的实用方法

第七个单元格定义使用torchaudio加载和重新采样音频数据的实用方法。将第七个单元格设置为:


 ### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]


  • read_audio_data方法加载指定的音频文件并返回音频数据的torch.Tensor多维矩阵以及音频的采样率。训练数据中的所有音频文件的采样率均为48000 Hz。此“原始”采样率由步骤 3.5中的常量ORIG_SAMPLING_RATE捕获。
  • resample方法用于将音频数据从采样率48000下采样到16000 。 wav2vec2 是在以16000 Hz 采样的音频上进行预训练的。 因此,用于微调的任何音频都必须具有相同的采样率。 在这种情况下,音频示例必须从48000 Hz 下采样到16000 Hz。 16000 Hz 由步骤 3.5中的常量TGT_SAMPLING_RATE捕获。

步骤 3.8 - 单元格 8:准备训练数据的实用方法

第八个单元格定义处理音频和转录数据的实用方法。将第八个单元格设置为:


 ### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding


  • process_speech_audio方法从提供的训练样本返回输入值。
  • process_target_text方法将每个文本转录编码为标签列表 - 即指向词汇表中字符的索引列表。您将在步骤 3.15中看到示例编码。

步骤 3.9 - 单元格 9:计算单词错误率的实用方法

第九个单元格是最后一个实用方法单元格,包含计算参考转录和预测转录之间的字错误率的方法。将第九个单元格设置为:


 ### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}

步骤 3.10 - 单元格 10:读取训练数据

第十个单元格使用步骤 3.6中定义的read_index_file_data方法读取男性说话者录音和女性说话者录音的训练数据索引文件。将第十个单元格设置为:


 ### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")


  • 可以看出,此时训练数据被管理在两个特定性别的列表中。截断后,数据将在步骤 3.12中合并。

步骤 3.11 - 单元格 11:截断训练数据

第 11 个单元格使用步骤 3.6中定义的truncate_training_dataset方法截断训练数据列表。将第 11 个单元格设置为:


 ### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)


  • 提醒一下,步骤 3.5中设置的NUM_LOAD_FROM_EACH_SET常量定义了每个数据集中要保留的样本数量。本指南中将该常量设置为1600总共3200样本。

步骤 3.12 - 单元格 12:合并训练样本数据

第十二个单元格合并截断的训练数据列表。将第十二个单元格设置为:


 ### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl

步骤 3.13 - 单元格 13:清理转录测试

第十三个单元格对每个训练数据样本进行迭代,并使用步骤 3.6中定义的clean_text方法清理相关的转录文本。将第十三个单元格设置为:


 for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])

步骤 3.14 - 单元格 14:创建词汇表

第 14 个单元格使用上一步中清理好的转录和步骤 3.6中定义的create_vocab方法创建词汇表。将第 14 个单元格设置为:


 ### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}


  • 词汇表以字典的形式存储,以字符为键,以词汇索引为值。

  • 您可以打印vocab_dict ,它会产生以下输出:


 {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}

步骤 3.15 - 单元格 15:向词汇表添加单词分隔符

第十五个单元格将单词分隔符|添加到词汇表中。将第十五个单元格设置为:


 ### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)


  • 单词分隔符用于将文本转录标记为标签列表。具体来说,它用于定义单词的结尾,并在初始化Wav2Vec2CTCTokenizer类时使用,如步骤 3.17中所示。

  • 例如,以下列表使用步骤 3.14中的词汇对no te entiendo nada进行编码:


 # Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}


  • 一个自然而然会出现的问题是:“为什么需要定义单词分隔符?”例如,书面英语和西班牙语的单词结尾用空格标记,因此使用空格字符作为单词分隔符应该是一件简单的事情。请记住,英语和西班牙语只是数千种语言中的两种;并不是所有书面语言都使用空格来标记单词边界。

步骤 3.16 - 单元格 16:导出词汇

第 16 个单元格将词汇表转储到文件中。将第 16 个单元格设置为:


 ### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)


  • 词汇文件将在下一步(步骤 3.17)中用于初始化Wav2Vec2CTCTokenizer类。

步骤 3.17 - CELL 17:初始化标记器

第 17 个单元格初始化Wav2Vec2CTCTokenizer的实例。将第 17 个单元格设置为:


 ### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )


  • 标记器用于对文本转录进行编码并将标签列表解码回文本。

  • 请注意, tokenizer初始化时将[UNK]分配给unk_token并将[PAD]分配给pad_token ,前者用于表示文本转录中的未知标记,后者用于在创建具有不同长度的转录批次时填充转录。标记器将把这两个值添加到词汇表中。

  • 在此步骤中初始化标记器还将向词汇表中添加两个额外的标记,即<s>/</s> ,分别用于划分句子的开始和结束。

  • |在此步骤中明确分配给word_delimiter_token ,以反映管道符号将用于划分单词的结尾,以符合我们在步骤 3.15中将字符添加到词汇表中的规定。 |符号是word_delimiter_token的默认值。因此,不需要明确设置,但这样做是为了清晰起见。

  • word_delimiter_token类似,明确为replace_word_delimiter_char分配了一个空格,这反映了管道符号|将用于替换文本转录中的空格字符。空格是replace_word_delimiter_char的默认值。因此,它也不需要明确设置,但这样做是为了清晰起见。

  • 您可以通过调用tokenizer上的get_vocab()方法来打印完整的 tokenizer 词汇表。


 vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}

步骤 3.18 - 单元格 18:初始化特征提取器

第 18 个单元格初始化Wav2Vec2FeatureExtractor的实例。将第 18 个单元格设置为:


 ### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )


  • 特征提取器用于从输入数据中提取特征,在本用例中,输入数据当然是音频数据。您将在步骤 3.20中为每个训练数据样本加载音频数据。
  • 传递给Wav2Vec2FeatureExtractor初始化程序的参数值都是默认值,但return_attention_mask除外,其默认值为False 。为了清晰起见,显示/传递了默认值。
  • feature_size参数指定输入特征(即音频数据特征)的维度大小。此参数的默认值为1
  • sampling_rate告诉特征提取器音频数据数字化的采样率。如步骤 3.7中所述,wav2vec2 是在16000 Hz 采样的音频上进行预训练的,因此16000是此参数的默认值。
  • padding_value参数指定在对不同长度的音频样本进行批处理时填充音频数据时使用的值。默认值为0.0
  • do_normalize用于指定是否应将输入数据转换为标准正态分布。默认值为True 。Wav2Vec2FeatureExtractor 类文档指出“[标准化] 可以帮助显Wav2Vec2FeatureExtractor提高某些模型的性能。”
  • return_attention_mask参数指定是否应传递注意掩码。对于此用例,该值设置为True

步骤 3.19 - 单元 19:初始化处理器

第十九个单元格初始化Wav2Vec2Processor的实例。将第十九个单元格设置为:


 ### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)


  • Wav2Vec2Processor类将步骤 3.17步骤 3.18中的tokenizerfeature_extractor分别组合成一个处理器。

  • 请注意,可以通过调用Wav2Vec2Processor类实例上的save_pretrained方法来保存处理器配置。


 processor.save_pretrained(OUTPUT_DIR_PATH)

步骤 3.20 - 单元格 20:加载音频数据

第二十个单元格加载all_training_samples列表中指定的每个音频文件。将第二十个单元格设置为:


 ### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })


  • 音频数据以torch.Tensor形式返回,并以字典列表的形式存储在all_input_data中。每个字典包含特定样本的音频数据以及音频的文本转录。
  • 请注意, read_audio_data方法也会返回音频数据的采样率。由于我们知道此用例中所有音频文件的采样率均为48000 Hz,因此在此步骤中忽略采样率。

步骤 3.21 - 单元格 21:将all_input_data转换为 Pandas DataFrame

第 21 个单元格将all_input_data列表转换为 Pandas DataFrame,以便于操作数据。将第 21 个单元格设置为:


 ### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)

步骤 3.22 - 单元格 22:处理音频数据和文本转录

第 22 个单元格使用步骤 3.19中初始化的processor从每个音频数据样本中提取特征,并将每个文本转录编码为标签列表。将第 22 个单元格设置为:


 ### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))

步骤 3.23 - 单元格 23:将输入数据拆分为训练数据集和验证数据集

第 23 个单元格使用步骤 3.5中的SPLIT_PCT常量将all_input_data_df DataFrame 拆分为训练和评估(验证)数据集。将第 23 个单元格设置为:


 ### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]


  • 本指南中的SPLIT_PCT值为0.10这意味着所有输入数据的 10% 将用于评估,90% 的数据将用于训练/微调。
  • 由于训练样本共有 3200 个,因此将使用 320 个样本进行评估,其余 2880 个样本用于微调模型。

步骤 3.24 - CELL 24:将训练和验证数据集转换为Dataset对象

第 24 个单元格将train_data_dfvalid_data_df DataFrames 转换为Dataset对象。将第 24 个单元格设置为:


 ### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)


  • Dataset对象被 HuggingFace Trainer类实例使用,正如您将在步骤 3.30中看到的那样。

  • 这些对象包含有关数据集以及数据集本身的元数据。

  • 您可以打印train_datavalid_data来查看两个Dataset对象的元数据。


 print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })

步骤 3.25 - 单元格 25:初始化预训练模型

第 25 个单元格初始化预训练的 XLS-R (0.3) 模型。将第 25 个单元格设置为:


 ### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )


  • Wav2Vec2ForCTC上调用的from_pretrained方法指定我们要为指定模型加载预训练权重。
  • MODEL常数在步骤 3.5中指定,并设置为facebook/wav2vec2-xls-r-300m反映 XLS-R (0.3) 模型。
  • ctc_loss_reduction参数指定应用于连接时间分类 (“CTC”) 损失函数输出的缩减类型。CTC 损失用于计算连续输入(本例中为音频数据)与目标序列(本例中为文本转录)之间的损失。通过将值设置为mean ,一批输入的输出损失将除以目标长度。然后计算该批次的平均值并将缩减应用于损失值。
  • pad_token_id指定批处理时用于填充的 token。它设置为步骤 3.17中初始化 tokenizer 时设置的[PAD] id。
  • vocab_size参数定义模型的词汇量。它是步骤 3.17中初始化 tokenizer 后的词汇量,反映了网络前向部分的输出层节点数。

步骤 3.26 - 单元格 26:冻结特征提取器权重

第 26 个单元格冻结特征提取器的预训练权重。将第 26 个单元格设置为:


 ### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()

步骤 3.27 - 单元格 27:设置训练参数

第 27 个单元格初始化将传递给Trainer实例的训练参数。将第 27 个单元格设置为:


 ### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )


  • TrainingArguments类接受100 多个参数
  • save_safetensors参数为False时指定应将微调模型保存到pickle文件中,而不是使用safetensors格式。
  • group_by_length参数为True时表示将长度大致相同的样本分组在一起。这可以最大限度地减少填充并提高训练效率。
  • per_device_train_batch_size设置每个训练小批次的样本数量。此参数通过步骤 3.5中分配的TRAIN_BATCH_SIZE常量设置为18这意味着每个时期 160 步。
  • per_device_eval_batch_size设置每个评估(保留)小批次的样本数量。此参数通过步骤 3.5中分配的EVAL_BATCH_SIZE常量设置为10
  • num_train_epochs设置训练周期数。此参数通过步骤 3.5中指定的TRAIN_EPOCHS常量设置为30这意味着训练期间总共 4,800 步。
  • gradient_checkpointing参数为True时,它有助于通过检查点梯度计算来节省内存,但会导致向后传递速度变慢。
  • evaluation_strategy参数设置为steps时,意味着将在训练期间按照参数eval_steps指定的间隔执行并记录评估。
  • 当将logging_strategy参数设置为steps时,意味着将按照参数logging_steps指定的间隔记录训练运行统计数据。
  • save_strategy参数设置为steps时,意味着将按照参数save_steps指定的间隔保存微调模型的检查点。
  • eval_steps设置保留数据评估之间的步数。此参数通过步骤 3.5中分配的EVAL_STEPS常量设置为100
  • save_steps设置保存微调模型检查点的步骤数。此参数通过步骤 3.5中分配的SAVE_STEPS常量设置为3200
  • logging_steps设置训练运行统计日志之间的步数。此参数通过步骤 3.5中分配的LOGGING_STEPS常量设置为100
  • learning_rate参数设置初始学习率。此参数通过步骤 3.5中指定的LEARNING_RATE常量设置为1e-4
  • warmup_steps参数设置将学习率从 0 线性预热到learning_rate设置的值的步数。此参数通过步骤 3.5中分配的WARMUP_STEPS常量设置为800

步骤 3.28 - 单元格 28:定义数据收集器逻辑

第 28 个单元格定义动态填充输入和目标序列的逻辑。将第 28 个单元格设置为:


 ### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch


  • 训练和评估输入标签对以小批量的形式传递给Trainer实例,该实例将在步骤 3.30中立即初始化。由于每个小批量中的输入序列和标签序列的长度各不相同,因此必须填充一些序列,以使它们的长度都相同。
  • DataCollatorCTCWithPadding类动态填充小批量数据。padding padding设置为True时,指定较短的音频输入特征序列和标签序列应与小批量中的最长序列具有相同的长度。
  • 步骤 3.18初始化特征提取器时,音频输入特征用值0.0填充。
  • 标签输入首先用步骤 3.17中初始化标记器时设置的填充值进行填充。这些值被替换为-100 ,以便在计算 WER 指标时忽略这些标签。

步骤 3.29 - CELL 29:初始化数据收集器实例

第 29 个单元格初始化上一步中定义的数据整理器实例。将第 29 个单元格设置为:


 ### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)

步骤 3.30 - 单元 30:初始化训练器

第三十个单元格初始化Trainer类的实例。将第三十个单元格设置为:


 ### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )


  • 如图所示, Trainer类的初始化方法如下:
    • 步骤 3.25中初始化的预训练model
    • 数据收集器在步骤 3.29中初始化。
    • 步骤 3.27中初始化的训练参数。
    • 步骤3.9中定义的WER评估方法。
    • 来自步骤 3.24 的train_data Dataset对象。
    • 来自步骤 3.24 的valid_data Dataset对象。
  • tokenizer参数被分配给processor.feature_extractor ,并与data_collator一起自动将输入填充到每个小批量的最大长度输入。

步骤 3.31 - 单元格 31:微调模型

第三十一个单元格调用Trainer类实例上的train方法来微调模型。将第三十一个单元格设置为:


 ### CELL 31: Finetune the model ### trainer.train()

步骤 3.32 - CELL 32:保存微调后的模型

第 32 个单元格是最后一个笔记本单元格。它通过调用Trainer实例上的save_model方法保存微调后的模型。将第 32 个单元格设置为:


 ### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)

步骤 4-训练和保存模型

步骤4.1-训练模型

现在笔记本的所有单元都已构建完毕,是时候开始进行微调了。


  1. 将 Kaggle Notebook 设置为使用NVIDIA GPU P100加速器运行。

  2. 在 Kaggle 上提交笔记本。

  3. 通过登录您的 WandB 帐户并找到相关运行来监控训练运行数据。


使用 NVIDIA GPU P100 加速器,超过 30 个 epoch 的训练大约需要 5 个小时。训练结束时,保留数据的 WER 应该会降至约 0.15。这不算是最先进的结果,但经过微调的模型对于许多应用来说仍然足够有用。

步骤 4.2 - 保存模型

经过微调的模型将输出到步骤 3.5中指定的常量OUTPUT_DIR_PATH指定的 Kaggle 目录。模型输出应包括以下文件:


 pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin


这些文件可以下载到本地。此外,您可以使用模型文件创建新的Kaggle 模型。Kaggle模型将与配套的推理指南一起使用,对经过微调的模型进行推理。


  1. 登录你的 Kaggle 账户。单击模型>新模型
  2. 模型标题字段中为您微调的模型添加标题。
  3. 单击创建模型
  4. 单击转到模型详细信息页面
  5. 单击“模型变体”下的“添加新变体”
  6. 框架选择菜单中选择变压器
  7. 点击添加新变体
  8. 将微调后的模型文件拖放到“上传数据”窗口中。或者,单击“浏览文件”按钮打开文件资源管理器窗口并选择微调后的模型文件。
  9. 将文件上传到 Kaggle 后,单击“创建”以创建Kaggle 模型

结论

恭喜您微调了 wav2vec2 XLS-R!请记住,您可以使用这些常规步骤对您想要的其他语言的模型进行微调。在本指南中生成的微调模型上运行推理相当简单。推理步骤将在本指南的单独配套指南中概述。请搜索我的 HackerNoon 用户名以查找配套指南。