Meta AI 于 2021 年底推出了wav2vec2 XLS-R(“XLS-R”)。 XLS-R 是一种用于跨语言语音表示学习的机器学习(“ML”)模型;它基于 128 种语言的 40 多万小时公开语音音频进行训练。该模型发布后,超越了 Meta AI 的XLSR-53跨语言模型,后者基于 53 种语言的约 5 万小时语音音频进行训练。
本指南介绍了使用Kaggle Notebook对 XLS-R 进行自动语音识别 (ASR) 微调的步骤。该模型将针对智利西班牙语进行微调,但您可以按照一般步骤对您想要的不同语言的 XLS-R 进行微调。
将在配套教程中介绍如何在经过微调的模型上运行推理,这使得本指南成为两部分中的第一部分。由于本微调指南有点长,我决定创建单独的推理专用指南。
假设您具有 ML 背景并了解基本的 ASR 概念。初学者可能难以理解构建步骤。
2020 年推出的原始 wav2vec2 模型是在 960 小时的Librispeech数据集语音音频和约 53,200 小时的LibriVox数据集语音音频上进行预训练的。发布时,有两种模型大小可供选择:具有 9500 万个参数的BASE模型和具有 3.17 亿个参数的LARGE模型。
另一方面,XLS-R 已对来自 5 个数据集的多语言语音音频进行了预训练:
XLS-R 共有 3 个模型: XLS-R (0.3B)具有 3 亿个参数、 XLS-R (1B)具有 10 亿个参数, XLS-R (2B)具有 20 亿个参数。本指南将使用 XLS-R (0.3B) 模型。
关于如何微调wav2vev2模型,有很多很棒的文章,也许这篇文章可以说是某种“黄金标准”。当然,这里的一般方法模仿了您在其他指南中可以找到的方法。您将:
但是,本指南与其他指南有三个主要区别:
要完成本指南,您需要具备:
在开始构建笔记本之前,查看下面的两个小节可能会有所帮助。它们描述了:
如引言中所述,XLS-R 模型将针对智利西班牙语进行微调。具体数据集是 Guevara-Rukoz 等人开发的智利西班牙语语音数据集。可在OpenSLR下载。该数据集由两个子数据集组成:(1) 2,636 个智利男性说话者的录音和 (2) 1,738 个智利女性说话者的录音。
每个子数据集包含一个line_index.tsv
索引文件。每个索引文件的每一行包含一对音频文件名和相关文件中音频的转录,例如:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
为了方便起见,我已将智利西班牙语语音数据集上传到 Kaggle。Kaggle 数据集中有一个用于智利男性说话者的录音,还有一个用于智利女性说话者的录音。这些 Kaggle 数据集将添加到您将按照本指南中的步骤构建的 Kaggle Notebook 中。
WER 是衡量自动语音识别模型性能的指标之一。WER 提供了一种机制来衡量文本预测与文本参考的接近程度。WER 通过记录 3 种类型的错误来实现这一点:
替换 ( S
):当预测中包含的单词与参考中的类似单词不同时,就会记录替换错误。例如,当预测拼错了参考中的单词时,就会发生这种情况。
删除( D
):当预测包含参考中不存在的单词时,就会记录为删除错误。
插入( I
):当预测不包含参考中存在的单词时,就会记录插入错误。
显然,WER 在单词级别起作用。WER 指标的公式如下:
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
以下是西班牙语的一个简单 WER 示例:
prediction: "Él está saliendo." reference: "Él está saltando."
表格有助于直观地看到预测中的错误:
文本 | 单词 1 | 单词 2 | 单词 3 |
---|---|---|---|
预言 | 埃勒 | 是 | 萨利恩多 |
参考 | 埃勒 | 是 | 跳跃 |
| 正确的 | 正确的 | 代换 |
预测包含 1 个替换错误、0 个删除错误和 0 个插入错误。因此,此示例的 WER 为:
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
显然,单词错误率并不一定能告诉我们存在哪些具体错误。在上面的例子中,WER 识别出单词 3在预测文本中包含错误,但它并没有告诉我们字符i和e在预测中是错误的。其他指标,例如字符错误率(“CER”),可用于更精确的错误分析。
您现在可以开始构建微调笔记本了。
必须配置您的 Kaggle Notebook 才能使用您的 WandB API 密钥将训练运行数据发送到 WandB。为此,您需要复制它。
www.wandb.com
。www.wandb.ai/authorize
。
xls-r-300m-chilean-spanish-asr
。Kaggle Secret将用于安全存储您的 WandB API 密钥。
WANDB_API_KEY
,并输入您的 WandB API 密钥作为值。WANDB_API_KEY
标签字段左侧的附加复选框。智利西班牙语语音数据集已作为 2 个不同的数据集上传至 Kaggle:
将这两个数据集都添加到您的 Kaggle Notebook。
以下 32 个子步骤按顺序构建微调笔记本的 32 个单元。
微调笔记本的第一个单元格安装依赖项。将第一个单元格设置为:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
包升级到最新版本。torchaudio torchaudio
用于加载音频文件并重新采样音频数据。jiwer
包,该包是稍后使用 HuggingFace Datasets
库load_metric
方法所必需的。第二个单元格导入所需的 Python 包。将第二个单元格设置为:
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
库和相关的Wav2Vec2*
类提供了用于微调的功能的支柱。第三个单元格导入 HuggingFace WER 评估指标。将第三个单元格设置为:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
第四个单元格检索您在步骤 2.2中设置的WANDB_API_KEY
密钥。将第四个单元格设置为:
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
第五个单元格设置将在整个笔记本中使用的常量。将第五个单元格设置为:
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
第六个单元格定义了用于读取数据集索引文件(请参阅上面的“训练数据集”子部分)以及清理转录文本和创建词汇表的实用方法。将第六个单元格设置为:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
read_index_file_data
方法读取line_index.tsv
数据集索引文件并生成包含音频文件名和转录数据的列表列表,例如:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
方法使用步骤 3.5中设置的NUM_LOAD_FROM_EACH_SET
常量截断列表索引文件数据。具体来说, NUM_LOAD_FROM_EACH_SET
常量用于指定应从每个数据集加载的音频样本数量。就本指南而言,该数字设置为1600
,这意味着最终将加载总共3200
音频样本。要加载所有样本,请将NUM_LOAD_FROM_EACH_SET
设置为字符串值all
。clean_text
方法用于删除每个文本转录中步骤 3.5中分配给SPECIAL_CHARS
的正则表达式所指定的字符。这些字符(包括标点符号)可以被删除,因为它们在训练模型以学习音频特征和文本转录之间的映射时不提供任何语义价值。create_vocab
方法根据干净的文本转录创建词汇表。简单来说,它从一组已清理的文本转录中提取所有唯一字符。您将在步骤 3.14中看到生成的词汇表的示例。第七个单元格定义使用torchaudio
加载和重新采样音频数据的实用方法。将第七个单元格设置为:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
方法加载指定的音频文件并返回音频数据的torch.Tensor
多维矩阵以及音频的采样率。训练数据中的所有音频文件的采样率均为48000
Hz。此“原始”采样率由步骤 3.5中的常量ORIG_SAMPLING_RATE
捕获。resample
方法用于将音频数据从采样率48000
下采样到16000
。 wav2vec2 是在以16000
Hz 采样的音频上进行预训练的。 因此,用于微调的任何音频都必须具有相同的采样率。 在这种情况下,音频示例必须从48000
Hz 下采样到16000
Hz。 16000
Hz 由步骤 3.5中的常量TGT_SAMPLING_RATE
捕获。第八个单元格定义处理音频和转录数据的实用方法。将第八个单元格设置为:
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
方法从提供的训练样本返回输入值。process_target_text
方法将每个文本转录编码为标签列表 - 即指向词汇表中字符的索引列表。您将在步骤 3.15中看到示例编码。第九个单元格是最后一个实用方法单元格,包含计算参考转录和预测转录之间的字错误率的方法。将第九个单元格设置为:
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
第十个单元格使用步骤 3.6中定义的read_index_file_data
方法读取男性说话者录音和女性说话者录音的训练数据索引文件。将第十个单元格设置为:
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
第 11 个单元格使用步骤 3.6中定义的truncate_training_dataset
方法截断训练数据列表。将第 11 个单元格设置为:
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
常量定义了每个数据集中要保留的样本数量。本指南中将该常量设置为1600
总共3200
样本。第十二个单元格合并截断的训练数据列表。将第十二个单元格设置为:
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
第十三个单元格对每个训练数据样本进行迭代,并使用步骤 3.6中定义的clean_text
方法清理相关的转录文本。将第十三个单元格设置为:
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
第 14 个单元格使用上一步中清理好的转录和步骤 3.6中定义的create_vocab
方法创建词汇表。将第 14 个单元格设置为:
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
词汇表以字典的形式存储,以字符为键,以词汇索引为值。
您可以打印vocab_dict
,它会产生以下输出:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
第十五个单元格将单词分隔符|
添加到词汇表中。将第十五个单元格设置为:
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
单词分隔符用于将文本转录标记为标签列表。具体来说,它用于定义单词的结尾,并在初始化Wav2Vec2CTCTokenizer
类时使用,如步骤 3.17中所示。
例如,以下列表使用步骤 3.14中的词汇对no te entiendo nada
进行编码:
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
第 16 个单元格将词汇表转储到文件中。将第 16 个单元格设置为:
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
类。第 17 个单元格初始化Wav2Vec2CTCTokenizer
的实例。将第 17 个单元格设置为:
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
标记器用于对文本转录进行编码并将标签列表解码回文本。
请注意, tokenizer
初始化时将[UNK]
分配给unk_token
并将[PAD]
分配给pad_token
,前者用于表示文本转录中的未知标记,后者用于在创建具有不同长度的转录批次时填充转录。标记器将把这两个值添加到词汇表中。
在此步骤中初始化标记器还将向词汇表中添加两个额外的标记,即<s>
和/</s>
,分别用于划分句子的开始和结束。
|
在此步骤中明确分配给word_delimiter_token
,以反映管道符号将用于划分单词的结尾,以符合我们在步骤 3.15中将字符添加到词汇表中的规定。 |
符号是word_delimiter_token
的默认值。因此,不需要明确设置,但这样做是为了清晰起见。
与word_delimiter_token
类似,明确为replace_word_delimiter_char
分配了一个空格,这反映了管道符号|
将用于替换文本转录中的空格字符。空格是replace_word_delimiter_char
的默认值。因此,它也不需要明确设置,但这样做是为了清晰起见。
您可以通过调用tokenizer
上的get_vocab()
方法来打印完整的 tokenizer 词汇表。
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
第 18 个单元格初始化Wav2Vec2FeatureExtractor
的实例。将第 18 个单元格设置为:
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
初始化程序的参数值都是默认值,但return_attention_mask
除外,其默认值为False
。为了清晰起见,显示/传递了默认值。feature_size
参数指定输入特征(即音频数据特征)的维度大小。此参数的默认值为1
。sampling_rate
告诉特征提取器音频数据数字化的采样率。如步骤 3.7中所述,wav2vec2 是在16000
Hz 采样的音频上进行预训练的,因此16000
是此参数的默认值。padding_value
参数指定在对不同长度的音频样本进行批处理时填充音频数据时使用的值。默认值为0.0
。do_normalize
用于指定是否应将输入数据转换为标准正态分布。默认值为True
。Wav2Vec2FeatureExtractor 类文档指出“[标准化] 可以帮助显Wav2Vec2FeatureExtractor
提高某些模型的性能。”return_attention_mask
参数指定是否应传递注意掩码。对于此用例,该值设置为True
。第十九个单元格初始化Wav2Vec2Processor
的实例。将第十九个单元格设置为:
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
Wav2Vec2Processor
类将步骤 3.17和步骤 3.18中的tokenizer
和feature_extractor
分别组合成一个处理器。
请注意,可以通过调用Wav2Vec2Processor
类实例上的save_pretrained
方法来保存处理器配置。
processor.save_pretrained(OUTPUT_DIR_PATH)
第二十个单元格加载all_training_samples
列表中指定的每个音频文件。将第二十个单元格设置为:
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
形式返回,并以字典列表的形式存储在all_input_data
中。每个字典包含特定样本的音频数据以及音频的文本转录。read_audio_data
方法也会返回音频数据的采样率。由于我们知道此用例中所有音频文件的采样率均为48000
Hz,因此在此步骤中忽略采样率。all_input_data
转换为 Pandas DataFrame第 21 个单元格将all_input_data
列表转换为 Pandas DataFrame,以便于操作数据。将第 21 个单元格设置为:
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
第 22 个单元格使用步骤 3.19中初始化的processor
从每个音频数据样本中提取特征,并将每个文本转录编码为标签列表。将第 22 个单元格设置为:
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
第 23 个单元格使用步骤 3.5中的SPLIT_PCT
常量将all_input_data_df
DataFrame 拆分为训练和评估(验证)数据集。将第 23 个单元格设置为:
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
值为0.10
这意味着所有输入数据的 10% 将用于评估,90% 的数据将用于训练/微调。Dataset
对象第 24 个单元格将train_data_df
和valid_data_df
DataFrames 转换为Dataset
对象。将第 24 个单元格设置为:
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Dataset
对象被 HuggingFace Trainer
类实例使用,正如您将在步骤 3.30中看到的那样。
这些对象包含有关数据集以及数据集本身的元数据。
您可以打印train_data
和valid_data
来查看两个Dataset
对象的元数据。
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
第 25 个单元格初始化预训练的 XLS-R (0.3) 模型。将第 25 个单元格设置为:
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
Wav2Vec2ForCTC
上调用的from_pretrained
方法指定我们要为指定模型加载预训练权重。MODEL
常数在步骤 3.5中指定,并设置为facebook/wav2vec2-xls-r-300m
反映 XLS-R (0.3) 模型。ctc_loss_reduction
参数指定应用于连接时间分类 (“CTC”) 损失函数输出的缩减类型。CTC 损失用于计算连续输入(本例中为音频数据)与目标序列(本例中为文本转录)之间的损失。通过将值设置为mean
,一批输入的输出损失将除以目标长度。然后计算该批次的平均值并将缩减应用于损失值。pad_token_id
指定批处理时用于填充的 token。它设置为步骤 3.17中初始化 tokenizer 时设置的[PAD]
id。vocab_size
参数定义模型的词汇量。它是步骤 3.17中初始化 tokenizer 后的词汇量,反映了网络前向部分的输出层节点数。第 26 个单元格冻结特征提取器的预训练权重。将第 26 个单元格设置为:
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
第 27 个单元格初始化将传递给Trainer
实例的训练参数。将第 27 个单元格设置为:
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
类接受100 多个参数。save_safetensors
参数为False
时指定应将微调模型保存到pickle
文件中,而不是使用safetensors
格式。group_by_length
参数为True
时表示将长度大致相同的样本分组在一起。这可以最大限度地减少填充并提高训练效率。per_device_train_batch_size
设置每个训练小批次的样本数量。此参数通过步骤 3.5中分配的TRAIN_BATCH_SIZE
常量设置为18
这意味着每个时期 160 步。per_device_eval_batch_size
设置每个评估(保留)小批次的样本数量。此参数通过步骤 3.5中分配的EVAL_BATCH_SIZE
常量设置为10
。num_train_epochs
设置训练周期数。此参数通过步骤 3.5中指定的TRAIN_EPOCHS
常量设置为30
这意味着训练期间总共 4,800 步。gradient_checkpointing
参数为True
时,它有助于通过检查点梯度计算来节省内存,但会导致向后传递速度变慢。evaluation_strategy
参数设置为steps
时,意味着将在训练期间按照参数eval_steps
指定的间隔执行并记录评估。logging_strategy
参数设置为steps
时,意味着将按照参数logging_steps
指定的间隔记录训练运行统计数据。save_strategy
参数设置为steps
时,意味着将按照参数save_steps
指定的间隔保存微调模型的检查点。eval_steps
设置保留数据评估之间的步数。此参数通过步骤 3.5中分配的EVAL_STEPS
常量设置为100
。save_steps
设置保存微调模型检查点的步骤数。此参数通过步骤 3.5中分配的SAVE_STEPS
常量设置为3200
。logging_steps
设置训练运行统计日志之间的步数。此参数通过步骤 3.5中分配的LOGGING_STEPS
常量设置为100
。learning_rate
参数设置初始学习率。此参数通过步骤 3.5中指定的LEARNING_RATE
常量设置为1e-4
。warmup_steps
参数设置将学习率从 0 线性预热到learning_rate
设置的值的步数。此参数通过步骤 3.5中分配的WARMUP_STEPS
常量设置为800
。第 28 个单元格定义动态填充输入和目标序列的逻辑。将第 28 个单元格设置为:
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
实例,该实例将在步骤 3.30中立即初始化。由于每个小批量中的输入序列和标签序列的长度各不相同,因此必须填充一些序列,以使它们的长度都相同。DataCollatorCTCWithPadding
类动态填充小批量数据。padding padding
设置为True
时,指定较短的音频输入特征序列和标签序列应与小批量中的最长序列具有相同的长度。0.0
填充。-100
,以便在计算 WER 指标时忽略这些标签。第 29 个单元格初始化上一步中定义的数据整理器实例。将第 29 个单元格设置为:
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
第三十个单元格初始化Trainer
类的实例。将第三十个单元格设置为:
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
类的初始化方法如下:model
。train_data
Dataset
对象。valid_data
Dataset
对象。tokenizer
参数被分配给processor.feature_extractor
,并与data_collator
一起自动将输入填充到每个小批量的最大长度输入。第三十一个单元格调用Trainer
类实例上的train
方法来微调模型。将第三十一个单元格设置为:
### CELL 31: Finetune the model ### trainer.train()
第 32 个单元格是最后一个笔记本单元格。它通过调用Trainer
实例上的save_model
方法保存微调后的模型。将第 32 个单元格设置为:
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
现在笔记本的所有单元都已构建完毕,是时候开始进行微调了。
将 Kaggle Notebook 设置为使用NVIDIA GPU P100加速器运行。
在 Kaggle 上提交笔记本。
通过登录您的 WandB 帐户并找到相关运行来监控训练运行数据。
使用 NVIDIA GPU P100 加速器,超过 30 个 epoch 的训练大约需要 5 个小时。训练结束时,保留数据的 WER 应该会降至约 0.15。这不算是最先进的结果,但经过微调的模型对于许多应用来说仍然足够有用。
经过微调的模型将输出到步骤 3.5中指定的常量OUTPUT_DIR_PATH
指定的 Kaggle 目录。模型输出应包括以下文件:
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
这些文件可以下载到本地。此外,您可以使用模型文件创建新的Kaggle 模型。Kaggle模型将与配套的推理指南一起使用,对经过微调的模型进行推理。
恭喜您微调了 wav2vec2 XLS-R!请记住,您可以使用这些常规步骤对您想要的其他语言的模型进行微调。在本指南中生成的微调模型上运行推理相当简单。推理步骤将在本指南的单独配套指南中概述。请搜索我的 HackerNoon 用户名以查找配套指南。