Meta AI は、2021 年末にwav2vec2 XLS-R (以下「XLS-R」)を導入しました。XLS-R は、言語間音声表現学習用の機械学習 (以下「ML」) モデルであり、128 言語にわたる 400,000 時間を超える公開音声オーディオでトレーニングされました。リリース時に、このモデルは、53 言語にわたる約 50,000 時間の音声オーディオでトレーニングされた Meta AI のXLSR-53言語間モデルを飛躍的に上回りました。
このガイドでは、 Kaggle Notebookを使用して、自動音声認識 (「ASR」) 用に XLS-R を微調整する手順について説明します。モデルはチリのスペイン語で微調整されますが、一般的な手順に従って、必要なさまざまな言語で XLS-R を微調整できます。
微調整されたモデルで推論を実行する方法については、付属のチュートリアルで説明されており、このガイドは 2 部構成の第 1 部となります。この微調整ガイドは少し長くなったため、推論に特化した別のガイドを作成することにしました。
すでに ML のバックグラウンドがあり、基本的な ASR の概念を理解していることが前提となります。初心者はビルド手順を理解しにくい場合があります。
2020 年に導入されたオリジナルの wav2vec2 モデルは、960 時間のLibrispeechデータセットの音声と約 53,200 時間のLibriVoxデータセットの音声で事前トレーニングされました。リリース時には、9,500 万のパラメータを持つBASEモデルと 3 億 1,700 万のパラメータを持つLARGEモデルの 2 つのモデル サイズが利用可能でした。
一方、XLS-R は、5 つのデータセットからの多言語音声オーディオで事前トレーニングされました。
XLS-R モデルには、3 億個のパラメータを持つXLS-R (0.3B) 、10 億個のパラメータを持つXLS-R (1B) 、20 億個のパラメータを持つXLS-R (2B)の 3 つがあります。このガイドでは、XLS-R (0.3B) モデルを使用します。
wav2vev2モデルを微調整する方法について素晴らしい記事がいくつかあり、おそらくこれはある種の「ゴールド スタンダード」でしょう。もちろん、ここでの一般的なアプローチは、他のガイドで紹介されているものと似ています。次のようになります。
ただし、このガイドと他のガイドの間には 3 つの重要な違いがあります。
ガイドを完了するには、次のものが必要です。
ノートブックの構築を始める前に、すぐ下の 2 つのサブセクションを確認すると役立つ場合があります。これらのサブセクションの内容は次のとおりです。
はじめにで述べたように、XLS-Rモデルはチリのスペイン語で微調整されます。具体的なデータセットは、Guevara-Rukozらが開発したチリのスペイン語音声データセットです。これはOpenSLRからダウンロードできます。データセットは、(1)チリの男性話者の2,636の音声録音と(2)チリの女性話者の1,738の音声録音の2つのサブデータセットで構成されています。
各サブデータセットには、 line_index.tsv
インデックス ファイルが含まれています。各インデックス ファイルの各行には、オーディオ ファイル名と、関連ファイル内のオーディオの転写のペアが含まれています。例:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
便宜上、チリのスペイン語音声データセットを Kaggle にアップロードしました。チリの男性話者の録音用に 1 つの Kaggle データセットがあり、チリの女性話者の録音用に 1 つの Kaggle データセットがあります。これらの Kaggle データセットは、このガイドの手順に従って構築する Kaggle ノートブックに追加されます。
WER は、自動音声認識モデルのパフォーマンスを測定するために使用できる 1 つの指標です。WER は、テキスト予測がテキスト参照にどれだけ近いかを測定するメカニズムを提供します。WER は、次の 3 種類のエラーを記録することでこれを実現します。
置換 ( S
): 予測に参照内の類似語とは異なる単語が含まれている場合に、置換エラーが記録されます。たとえば、予測で参照内の単語のスペルが間違っている場合に発生します。
削除( D
):予測に参照に存在しない単語が含まれている場合、削除エラーが記録されます。
挿入( I
):予測に参照に存在する単語が含まれていない場合、挿入エラーが記録されます。
明らかに、WER は単語レベルで機能します。WER メトリックの式は次のとおりです。
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
スペイン語での簡単な WER の例は次のとおりです。
prediction: "Él está saliendo." reference: "Él está saltando."
表は予測の誤差を視覚化するのに役立ちます。
文章 | 単語1 | 単語2 | 単語3 |
---|---|---|---|
予測 | エル | そうです | サリエンド |
参照 | エル | そうです | サルタンド |
| 正しい | 正しい | 代替 |
予測には、置換エラーが 1 つ、削除エラーが 0 つ、挿入エラーが 0 つ含まれています。したがって、この例の WER は次のようになります。
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
単語エラー率が必ずしも具体的なエラーが何であるかを教えてくれるとは限らないことは明らかです。上記の例では、WER は予測されたテキストのWORD 3にエラーが含まれていることを識別しますが、予測で文字iとeが間違っていることは教えてくれません。文字エラー率 (CER) などの他の指標は、より正確なエラー分析に使用できます。
これで、微調整ノートブックの構築を開始する準備が整いました。
Kaggle Notebook は、WandB API キーを使用してトレーニング実行データを WandB に送信するように設定する必要があります。そのためには、キーをコピーする必要があります。
www.wandb.com
で WandB にログインします。www.wandb.ai/authorize
に移動します。
xls-r-300m-chilean-spanish-asr
を使用します。WandB API キーを安全に保存するために、 Kaggle Secret が使用されます。
WANDB_API_KEY
を入力し、値に WandB API キーを入力します。WANDB_API_KEY
ラベル フィールドの左側にある[添付]チェックボックスがオンになっていることを確認します。チリのスペイン語音声データセットは、 2 つの異なるデータセットとして Kaggle にアップロードされました。
これら両方のデータセットを Kaggle ノートブックに追加します。
次の 32 のサブステップでは、微調整ノートブックの 32 個のセルそれぞれを順番に構築します。
微調整ノートブックの最初のセルは依存関係をインストールします。最初のセルを次のように設定します。
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
パッケージを最新バージョンにアップグレードします。torchaudio torchaudio
、オーディオ ファイルを読み込み、オーディオ データを再サンプリングするために使用されます。Datasets
ライブラリのload_metric
メソッドを使用するために必要なjiwer
パッケージをインストールします。2 番目のセルは必要な Python パッケージをインポートします。2 番目のセルを次のように設定します。
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
ライブラリと関連するWav2Vec2*
クラスが、微調整に使用される機能のバックボーンを提供していることは言及する価値があります。3 番目のセルは、HuggingFace WER 評価メトリックをインポートします。3 番目のセルを次のように設定します。
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
4 番目のセルは、ステップ 2.2で設定したWANDB_API_KEY
シークレットを取得します。4 番目のセルを次のように設定します。
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
5 番目のセルは、ノートブック全体で使用される定数を設定します。5 番目のセルを次のように設定します。
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
6 番目のセルは、データセット インデックス ファイル (上記のトレーニング データセットのサブセクションを参照) の読み取り、および転写テキストのクリーニングと語彙の作成を行うユーティリティ メソッドを定義します。6 番目のセルを次のように設定します。
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
read_index_file_data
メソッドは、 line_index.tsv
データセット インデックス ファイルを読み取り、オーディオ ファイル名と転写データを含むリストのリストを生成します。例:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
メソッドは、ステップ 3.5で設定されたNUM_LOAD_FROM_EACH_SET
定数を使用して、リスト インデックス ファイルのデータを切り捨てます。具体的には、 NUM_LOAD_FROM_EACH_SET
定数は、各データセットからロードするオーディオ サンプルの数を指定するために使用されます。このガイドでは、この数は1600
に設定されており、最終的に合計3200
のオーディオ サンプルがロードされることを意味します。すべてのサンプルをロードするには、 NUM_LOAD_FROM_EACH_SET
を文字列値all
に設定します。clean_text
メソッドは、ステップ 3.5でSPECIAL_CHARS
に割り当てられた正規表現で指定された文字を各テキスト転写から削除するために使用されます。句読点を含むこれらの文字は、オーディオ機能とテキスト転写間のマッピングを学習するようにモデルをトレーニングするときに意味的な価値を提供しないため、削除できます。create_vocab
メソッドは、クリーンなテキスト転写から語彙を作成します。簡単に言えば、クリーンなテキスト転写のセットからすべての一意の文字を抽出します。生成された語彙の例は、ステップ 3.14で確認できます。7 番目のセルは、 torchaudio
を使用してオーディオ データを読み込み、再サンプリングするユーティリティ メソッドを定義します。7 番目のセルを次のように設定します。
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
メソッドは、指定されたオーディオ ファイルを読み込み、オーディオ データのtorch.Tensor
多次元マトリックスとオーディオのサンプリング レートを返します。トレーニング データ内のすべてのオーディオ ファイルのサンプリング レートは48000
Hz です。この「元の」サンプリング レートは、ステップ 3.5の定数ORIG_SAMPLING_RATE
によって取得されます。resample
メソッドは、オーディオ データをサンプリング レート48000
から16000
にダウンサンプリングするために使用されます。wav2vec2 は、 16000
Hz でサンプリングされたオーディオで事前トレーニングされています。したがって、微調整に使用するオーディオはすべて同じサンプリング レートである必要があります。この場合、オーディオ サンプルは48000
Hz から16000
Hz にダウンサンプリングする必要があります。16000 Hz 16000
、ステップ 3.5の定数TGT_SAMPLING_RATE
によって取得されます。8 番目のセルは、オーディオとトランスクリプション データを処理するユーティリティ メソッドを定義します。8 番目のセルを次のように設定します。
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
メソッドは、提供されたトレーニング サンプルから入力値を返します。process_target_text
メソッドは、各テキスト転写をラベルのリスト、つまり語彙内の文字を参照するインデックスのリストとしてエンコードします。サンプルのエンコードは、ステップ 3.15で確認できます。9 番目のセルは、最後のユーティリティ メソッド セルであり、参照転写と予測転写の間の単語エラー率を計算するメソッドが含まれています。9 番目のセルを次のように設定します。
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
10 番目のセルは、ステップ 3.6で定義されたread_index_file_data
メソッドを使用して、男性話者の録音と女性話者の録音のトレーニング データ インデックス ファイルを読み取ります。10 番目のセルを次のように設定します。
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
11 番目のセルは、ステップ 3.6で定義されたtruncate_training_dataset
メソッドを使用してトレーニング データ リストを切り捨てます。11 番目のセルを次のように設定します。
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
定数は、各データセットから保持するサンプルの数を定義します。このガイドでは、定数は1600
に設定されており、合計3200
サンプルになります。12 番目のセルは、切り捨てられたトレーニング データ リストを結合します。12 番目のセルを次のように設定します。
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
13 番目のセルは、各トレーニング データ サンプルを反復処理し、ステップ 3.6で定義されたclean_text
メソッドを使用して、関連する転写テキストをクリーンアップします。13 番目のセルを次のように設定します。
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
14 番目のセルは、前のステップでクリーンアップされた転写とステップ 3.6で定義されたcreate_vocab
メソッドを使用して語彙を作成します。14 番目のセルを次のように設定します。
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
語彙は、文字をキーとし、語彙インデックスを値とする辞書として保存されます。
vocab_dict
を印刷すると、次の出力が生成されます。
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
15 番目のセルは、語彙に単語区切り文字|
を追加します。15 番目のセルを次のように設定します。
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
単語区切り文字は、テキスト転写をラベルのリストとしてトークン化するときに使用されます。具体的には、単語の終わりを定義するために使用され、ステップ 3.17で説明するように、 Wav2Vec2CTCTokenizer
クラスを初期化するときに使用されます。
たとえば、次のリストは、ステップ3.14の語彙を使用してno te entiendo nada
をエンコードします。
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
16 番目のセルは語彙をファイルにダンプします。16 番目のセルを次のように設定します。
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
クラスを初期化するために使用されます。17 番目のセルは、 Wav2Vec2CTCTokenizer
のインスタンスを初期化します。17 番目のセルを次のように設定します。
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
トークナイザーは、テキストの転写をエンコードし、ラベルのリストをテキストにデコードするために使用されます。
tokenizer
、 unk_token
に[UNK]
が割り当てられ、 pad_token
に[PAD]
が割り当てられて初期化されることに注意してください。前者はテキスト転写内の不明なトークンを表すために使用され、後者は異なる長さの転写のバッチを作成するときに転写を埋め込むために使用されます。これらの 2 つの値は、トークナイザーによって語彙に追加されます。
このステップでトークナイザーを初期化すると、語彙に 2 つの追加トークン、つまりそれぞれ文の始まりと終わりを区別するために使用される<s>
と/</s>
も追加されます。
このステップでは、ステップ 3.15で語彙に文字を追加したことに応じて、パイプ記号が単語の終わりを区切るために使用されることを反映して、 |
がword_delimiter_token
に明示的に割り当てられます。 |
記号はword_delimiter_token
のデフォルト値です。したがって、明示的に設定する必要はありませんでしたが、わかりやすくするために設定しました。
word_delimiter_token
と同様に、パイプ記号|
がテキスト転写内の空白文字の置き換えに使用されることを反映して、 replace_word_delimiter_char
に 1 つのスペースが明示的に割り当てられます。空白はreplace_word_delimiter_char
のデフォルト値です。したがって、これも明示的に設定する必要はありませんでしたが、わかりやすくするために設定しました。
tokenizer
のget_vocab()
メソッドを呼び出すと、完全なトークナイザー語彙を印刷できます。
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
18 番目のセルは、 Wav2Vec2FeatureExtractor
のインスタンスを初期化します。18 番目のセルを次のように設定します。
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
初期化子に渡されるパラメータ値はすべてデフォルト値ですが、 return_attention_mask
はデフォルトでFalse
に設定されます。デフォルト値はわかりやすくするために表示/渡されています。feature_size
パラメータは、入力特徴(つまり、オーディオ データ特徴)の次元サイズを指定します。このパラメータのデフォルト値は1
です。sampling_rate
、オーディオデータをデジタル化するサンプリングレートを特徴抽出器に伝えます。ステップ 3.7で説明したように、wav2vec2 は16000
Hz でサンプリングされたオーディオで事前トレーニングされているため、このパラメータのデフォルト値は16000
です。padding_value
パラメータは、異なる長さのオーディオ サンプルをバッチ処理するときに必要な、オーディオ データのパディングに使用する値を指定します。デフォルト値は0.0
です。do_normalize
、入力データを標準正規分布に変換するかどうかを指定するために使用されます。デフォルト値はTrue
です。Wav2Vec2FeatureExtractor クラスのドキュメントには、「[正規化] により、一部のモデルのパフォーマンスが大幅に向上する可能性があります」と記載されていますWav2Vec2FeatureExtractor
return_attention_mask
パラメータは、アテンション マスクを渡すかどうかを指定します。このユース ケースでは、値はTrue
に設定されています。19 番目のセルはWav2Vec2Processor
のインスタンスを初期化します。19 番目のセルを次のように設定します。
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
Wav2Vec2Processor
クラスは、ステップ 3.17とステップ 3.18のtokenizer
とfeature_extractor
それぞれ 1 つのプロセッサに結合します。
Wav2Vec2Processor
クラス インスタンスでsave_pretrained
メソッドを呼び出すことによって、プロセッサ構成を保存できることに注意してください。
processor.save_pretrained(OUTPUT_DIR_PATH)
20 番目のセルは、 all_training_samples
リストで指定された各オーディオ ファイルを読み込みます。20 番目のセルを次のように設定します。
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
として返され、辞書のリストとしてall_input_data
に保存されます。各辞書には、特定のサンプルのオーディオ データと、オーディオのテキスト転写が含まれています。read_audio_data
メソッドはオーディオ データのサンプリング レートも返すことに注意してください。このユース ケースではすべてのオーディオ ファイルのサンプリング レートが48000
Hz であることがわかっているため、この手順ではサンプリング レートは無視されます。all_input_data
Pandas DataFrame に変換する21 番目のセルは、 all_input_data
リストを Pandas DataFrame に変換して、データの操作を容易にします。21 番目のセルを次のように設定します。
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
22 番目のセルは、ステップ 3.19で初期化されたprocessor
を使用して、各オーディオ データ サンプルから特徴を抽出し、各テキスト転写をラベルのリストとしてエンコードします。22 番目のセルを次のように設定します。
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
23 番目のセルは、ステップ 3.5のSPLIT_PCT
定数を使用して、 all_input_data_df
DataFrame をトレーニング データセットと評価 (検証) データセットに分割します。23 番目のセルを次のように設定します。
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
値は0.10
です。これは、すべての入力データの 10% が評価用に保持され、データの 90% がトレーニング/微調整に使用されることを意味します。Dataset
オブジェクトに変換する24 番目のセルは、 train_data_df
およびvalid_data_df
DataFrame をDataset
オブジェクトに変換します。24 番目のセルを次のように設定します。
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
ステップ 3.30でわかるように、 Dataset
オブジェクトは HuggingFace Trainer
クラス インスタンスによって使用されます。
これらのオブジェクトには、データセット自体だけでなく、データセットに関するメタデータも含まれています。
train_data
とvalid_data
を印刷して、両方のDataset
オブジェクトのメタデータを表示できます。
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
25 番目のセルは、事前トレーニング済みの XLS-R (0.3) モデルを初期化します。25 番目のセルを次のように設定します。
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
Wav2Vec2ForCTC
で呼び出されるfrom_pretrained
メソッドは、指定されたモデルの事前トレーニング済みの重みを読み込むことを指定します。MODEL
定数はステップ3.5で指定され、XLS-R(0.3)モデルを反映してfacebook/wav2vec2-xls-r-300m
に設定されました。ctc_loss_reduction
パラメータは、コネクショニスト時間分類 ("CTC") 損失関数の出力に適用する削減のタイプを指定します。CTC 損失は、連続入力 (この場合はオーディオ データ) とターゲット シーケンス (この場合はテキスト転写) 間の損失を計算するために使用されます。値をmean
に設定すると、入力バッチの出力損失がターゲットの長さで除算されます。次に、バッチ全体の平均が計算され、削減が損失値に適用されます。pad_token_id
、バッチ処理時にパディングに使用するトークンを指定します。これは、ステップ3.17でトークナイザーを初期化するときに設定された[PAD]
idに設定されます。vocab_size
パラメータは、モデルの語彙サイズを定義します。これは、ステップ 3.17でトークナイザーを初期化した後の語彙サイズであり、ネットワークの前方部分の出力層ノードの数を反映します。26 番目のセルは、特徴抽出器の事前トレーニング済みの重みを固定します。26 番目のセルを次のように設定します。
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
27 番目のセルは、 Trainer
インスタンスに渡されるトレーニング引数を初期化します。27 番目のセルを次のように設定します。
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
クラスは100 を超えるパラメータを受け入れます。save_safetensors
パラメータがFalse
の場合、 safetensors
形式を使用する代わりに、微調整されたモデルをpickle
ファイルに保存することを指定します。group_by_length
パラメータがTrue
の場合、ほぼ同じ長さのサンプルをグループ化する必要があることを示します。これにより、パディングが最小限に抑えられ、トレーニングの効率が向上します。per_device_train_batch_size
、トレーニング ミニバッチあたりのサンプル数を設定します。このパラメータは、ステップ 3.5で割り当てられたTRAIN_BATCH_SIZE
定数によって18
に設定されます。これは、エポックあたり 160 ステップを意味します。per_device_eval_batch_size
、評価(ホールドアウト)ミニバッチあたりのサンプル数を設定します。このパラメータは、ステップ 3.5で割り当てられたEVAL_BATCH_SIZE
定数によって10
に設定されます。num_train_epochs
はトレーニング エポックの数を設定します。このパラメータは、ステップ 3.5で割り当てられたTRAIN_EPOCHS
定数によって30
に設定されます。これは、トレーニング中に合計 4,800 ステップを意味します。gradient_checkpointing
パラメータがTrue
の場合、勾配計算をチェックポイントすることでメモリを節約できますが、逆方向パスの速度は低下します。evaluation_strategy
パラメータをsteps
に設定すると、トレーニング中にパラメータeval_steps
で指定された間隔で評価が実行され、ログに記録されます。logging_strategy
パラメータをsteps
に設定すると、トレーニング実行の統計がパラメータlogging_steps
で指定された間隔で記録されることを意味します。save_strategy
パラメータをsteps
に設定すると、微調整されたモデルのチェックポイントが、パラメータsave_steps
で指定された間隔で保存されることを意味します。eval_steps
、ホールドアウト データの評価間のステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたEVAL_STEPS
定数によって100
に設定されます。save_steps
、微調整されたモデルのチェックポイントが保存されるまでのステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたSAVE_STEPS
定数によって3200
に設定されます。logging_steps
、トレーニング実行統計のログ間のステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたLOGGING_STEPS
定数によって100
に設定されます。learning_rate
パラメータは初期学習率を設定します。このパラメータは、ステップ3.5で割り当てられたLEARNING_RATE
定数によって1e-4
に設定されます。warmup_steps
パラメータは、学習率を 0 からlearning_rate
で設定された値まで線形にウォームアップするステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたWARMUP_STEPS
定数によって800
に設定されます。28 番目のセルは、入力シーケンスとターゲット シーケンスを動的にパディングするためのロジックを定義します。28 番目のセルを次のように設定します。
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
インスタンスにミニバッチで渡されます。入力シーケンスとラベル シーケンスの長さは各ミニバッチで異なるため、一部のシーケンスはすべて同じ長さになるようにパディングする必要があります。DataCollatorCTCWithPadding
クラスは、ミニバッチ データを動的にパディングします。 padding
パラメータをTrue
に設定すると、短いオーディオ入力機能シーケンスとラベル シーケンスの長さがミニバッチ内の最長シーケンスと同じになるように指定されます。0.0
でパディングされます。-100
に置き換えられ、WER メトリックを計算するときにこれらのラベルは無視されます。29 番目のセルは、前の手順で定義したデータ コレータのインスタンスを初期化します。29 番目のセルを次のように設定します。
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
30 番目のセルは、 Trainer
クラスのインスタンスを初期化します。30 番目のセルを次のように設定します。
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
クラスは次のように初期化されます。model
。train_data
Dataset
オブジェクト。valid_data
Dataset
オブジェクト。tokenizer
パラメータは、 processor.feature_extractor
に割り当てられ、 data_collator
と連携して、各ミニバッチの最大長の入力に自動的に入力を埋め込みます。31 番目のセルは、 Trainer
クラス インスタンスのtrain
メソッドを呼び出して、モデルを微調整します。31 番目のセルを次のように設定します。
### CELL 31: Finetune the model ### trainer.train()
32 番目のセルは最後のノートブック セルです。Trainer Trainer
のsave_model
メソッドを呼び出して、微調整されたモデルを保存します。32 番目のセルを次のように設定します。
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
ノートブックのすべてのセルが構築されたので、微調整を開始します。
Kaggle Notebook をNVIDIA GPU P100アクセラレータで実行するように設定します。
ノートブックを Kaggle にコミットします。
WandB アカウントにログインし、関連する実行を見つけて、トレーニング実行データを監視します。
NVIDIA GPU P100 アクセラレータを使用すると、30 エポックを超えるトレーニングには約 5 時間かかります。ホールドアウト データの WER は、トレーニング終了時に約 0.15 に低下します。最先端の結果ではありませんが、微調整されたモデルは、多くのアプリケーションで十分に役立ちます。
微調整されたモデルは、ステップ 3.5で指定された定数OUTPUT_DIR_PATH
で指定された Kaggle ディレクトリに出力されます。モデル出力には次のファイルが含まれる必要があります。
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
これらのファイルはローカルにダウンロードできます。さらに、モデル ファイルを使用して新しいKaggle モデルを作成することもできます。Kaggleモデルは、付属の推論ガイドとともに使用され、微調整されたモデルで推論を実行します。
wav2vec2 XLS-R の微調整おめでとうございます。これらの一般的な手順を使用して、必要な他の言語でモデルを微調整できることを覚えておいてください。このガイドで生成された微調整済みモデルで推論を実行するのは非常に簡単です。推論手順については、このガイドとは別のコンパニオン ガイドで概説します。コンパニオン ガイドを見つけるには、私の HackerNoon ユーザー名を検索してください。