Meta AI đã giới thiệu wav2vec2 XLS-R ("XLS-R") vào cuối năm 2021. XLS-R là mô hình máy học ("ML") để học cách biểu đạt giọng nói đa ngôn ngữ; và nó đã được đào tạo trên hơn 400.000 giờ âm thanh giọng nói có sẵn công khai trên 128 ngôn ngữ. Sau khi phát hành, mô hình này thể hiện một bước nhảy vọt so với mô hình đa ngôn ngữ XLSR-53 của Meta AI, mô hình này đã được đào tạo về khoảng 50.000 giờ âm thanh lời nói trên 53 ngôn ngữ.
Hướng dẫn này giải thích các bước để tinh chỉnh XLS-R để nhận dạng giọng nói tự động ("ASR") bằng Sổ tay Kaggle . Mô hình này sẽ được tinh chỉnh bằng tiếng Tây Ban Nha Chile, nhưng bạn có thể làm theo các bước chung để tinh chỉnh XLS-R trên các ngôn ngữ khác nhau mà bạn mong muốn.
Việc chạy suy luận trên mô hình tinh chỉnh sẽ được mô tả trong phần hướng dẫn đi kèm, khiến hướng dẫn này trở thành phần đầu tiên trong hai phần. Tôi quyết định tạo một hướng dẫn riêng dành riêng cho suy luận vì hướng dẫn tinh chỉnh này hơi dài.
Giả sử bạn đã có nền tảng ML hiện có và hiểu các khái niệm ASR cơ bản. Người mới bắt đầu có thể gặp khó khăn khi theo dõi/hiểu các bước xây dựng.
Mô hình wav2vec2 ban đầu được giới thiệu vào năm 2020 đã được đào tạo trước trên 960 giờ âm thanh giọng nói của tập dữ liệu Librispeech và ~53.200 giờ âm thanh giọng nói của tập dữ liệu LibriVox . Sau khi phát hành, có hai kích cỡ mô hình: mô hình BASE với 95 triệu tham số và mô hình LARGE với 317 triệu tham số.
Mặt khác, XLS-R đã được huấn luyện trước về âm thanh lời nói đa ngôn ngữ từ 5 bộ dữ liệu:
Có 3 mẫu XLS-R: XLS-R (0,3B) với 300 triệu thông số, XLS-R (1B) với 1 tỷ thông số và XLS-R (2B) với 2 tỷ thông số. Hướng dẫn này sẽ sử dụng mẫu XLS-R (0,3B).
Có một số bài viết hay về cách tinh chỉnh các mô hình wav2vev2 , có lẽ mô hình này là một loại "tiêu chuẩn vàng". Tất nhiên, cách tiếp cận chung ở đây bắt chước những gì bạn sẽ tìm thấy trong các hướng dẫn khác. Bạn sẽ:
Tuy nhiên, có ba điểm khác biệt chính giữa hướng dẫn này và các hướng dẫn khác:
Để hoàn thành hướng dẫn, bạn sẽ cần phải có:
Trước khi bắt đầu xây dựng sổ ghi chép, bạn có thể xem lại hai phần phụ ngay bên dưới. Họ mô tả:
Như đã đề cập trong phần Giới thiệu , mẫu XLS-R sẽ được tinh chỉnh bằng tiếng Tây Ban Nha ở Chile. Tập dữ liệu cụ thể là Tập dữ liệu lời nói tiếng Tây Ban Nha ở Chile được phát triển bởi Guevara-Rukoz et al. Nó có sẵn để tải xuống trên OpenSLR . Bộ dữ liệu bao gồm hai bộ dữ liệu phụ: (1) 2.636 bản ghi âm của những người nói tiếng Chile là nam và (2) 1.738 bản ghi âm của những người nói tiếng Chile là nữ.
Mỗi tập dữ liệu con bao gồm một tệp chỉ mục line_index.tsv
. Mỗi dòng của mỗi tệp chỉ mục chứa một cặp tên tệp âm thanh và bản ghi âm của tệp được liên kết, ví dụ:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
Tôi đã tải Bộ dữ liệu lời nói tiếng Tây Ban Nha tiếng Chile lên Kaggle để thuận tiện. Có một tập dữ liệu Kaggle cho các bản ghi âm của những người nói tiếng Chile là nam giới và một bộ dữ liệu Kaggle cho các bản ghi âm của những người nói tiếng Chile là nữ . Các bộ dữ liệu Kaggle này sẽ được thêm vào Sổ tay Kaggle mà bạn sẽ xây dựng theo các bước trong hướng dẫn này.
WER là một số liệu có thể được sử dụng để đo lường hiệu suất của các mô hình nhận dạng giọng nói tự động. WER cung cấp một cơ chế để đo lường mức độ gần gũi của dự đoán văn bản với tham chiếu văn bản. WER thực hiện điều này bằng cách ghi lại 3 loại lỗi:
thay thế ( S
): Lỗi thay thế được ghi lại khi dự đoán chứa một từ khác với từ tương tự trong tham chiếu. Ví dụ: điều này xảy ra khi dự đoán viết sai chính tả một từ trong tài liệu tham khảo.
xóa ( D
): Lỗi xóa được ghi lại khi dự đoán chứa một từ không có trong tham chiếu.
phần chèn ( I
): Lỗi chèn được ghi lại khi dự đoán không chứa từ nào có trong tham chiếu.
Rõ ràng, WER hoạt động ở cấp độ từ. Công thức cho số liệu WER như sau:
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
Một ví dụ WER đơn giản bằng tiếng Tây Ban Nha như sau:
prediction: "Él está saliendo." reference: "Él está saltando."
Một bảng giúp hình dung các lỗi trong dự đoán:
CHỮ | TỪ 1 | TỪ 2 | TỪ 3 |
---|---|---|---|
sự dự đoán | Él | está | saliendo |
thẩm quyền giải quyết | Él | está | muối |
| Chính xác | Chính xác | thay thế |
Dự đoán có 1 lỗi thay thế, 0 lỗi xóa và 0 lỗi chèn. Vì vậy, WER cho ví dụ này là:
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
Rõ ràng là Tỷ lệ lỗi từ không nhất thiết cho chúng ta biết những lỗi cụ thể nào đang tồn tại. Trong ví dụ trên, WER xác định rằng WORD 3 có lỗi trong văn bản được dự đoán, nhưng nó không cho chúng ta biết rằng các ký tự i và e sai trong dự đoán. Các số liệu khác, chẳng hạn như Tỷ lệ lỗi ký tự ("CER"), có thể được sử dụng để phân tích lỗi chính xác hơn.
Bây giờ bạn đã sẵn sàng để bắt đầu xây dựng sổ ghi chép tinh chỉnh.
Sổ tay Kaggle của bạn phải được định cấu hình để gửi dữ liệu chạy đào tạo tới WandB bằng khóa API WandB của bạn. Để làm được điều đó, bạn cần sao chép nó.
www.wandb.com
.www.wandb.ai/authorize
.
xls-r-300m-chilean-spanish-asr
.Bí mật Kaggle sẽ được sử dụng để lưu trữ khóa API WandB của bạn một cách an toàn.
WANDB_API_KEY
vào trường Nhãn và nhập khóa API WandB của bạn cho giá trị.WANDB_API_KEY
được chọn.Tập dữ liệu giọng nói tiếng Tây Ban Nha ở Chile đã được tải lên Kaggle dưới dạng 2 tập dữ liệu riêng biệt:
Thêm cả hai bộ dữ liệu này vào Sổ tay Kaggle của bạn.
32 bước phụ sau đây sẽ xây dựng từng ô trong số 32 ô của sổ ghi chép tinh chỉnh theo thứ tự.
Ô đầu tiên của sổ ghi chép tinh chỉnh sẽ cài đặt các phần phụ thuộc. Đặt ô đầu tiên thành:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
lên phiên bản mới nhất. torchaudio
sẽ được sử dụng để tải các tệp âm thanh và lấy mẫu lại dữ liệu âm thanh.jiwer
được yêu cầu để sử dụng phương thức load_metric
của thư viện HuggingFace Datasets
được sử dụng sau này.Việc nhập ô thứ hai yêu cầu các gói Python. Đặt ô thứ hai thành:
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
HuggingFace và các lớp Wav2Vec2*
liên quan cung cấp nền tảng cho chức năng được sử dụng để tinh chỉnh.Ô thứ ba nhập số liệu đánh giá HuggingFace WER. Đặt ô thứ ba thành:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
Ô thứ tư lấy bí mật WANDB_API_KEY
của bạn đã được đặt ở Bước 2.2 . Đặt ô thứ tư thành:
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
Ô thứ năm đặt các hằng số sẽ được sử dụng trong toàn bộ sổ ghi chép. Đặt ô thứ năm thành:
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
Ô thứ sáu xác định các phương thức tiện ích để đọc các tệp chỉ mục tập dữ liệu (xem phần phụ Tập dữ liệu huấn luyện ở trên), cũng như để làm sạch văn bản phiên âm và tạo từ vựng. Đặt ô thứ sáu thành:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
Phương thức read_index_file_data
đọc tệp chỉ mục tập dữ liệu line_index.tsv
và tạo danh sách các danh sách có tên tệp âm thanh và dữ liệu phiên âm, ví dụ:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
cắt bớt dữ liệu tệp chỉ mục danh sách bằng cách sử dụng hằng NUM_LOAD_FROM_EACH_SET
được đặt ở Bước 3.5 . Cụ thể, hằng số NUM_LOAD_FROM_EACH_SET
được sử dụng để chỉ định số lượng mẫu âm thanh cần tải từ mỗi tập dữ liệu. Vì mục đích của hướng dẫn này, con số được đặt ở 1600
, nghĩa là cuối cùng tổng cộng 3200
mẫu âm thanh sẽ được tải. Để tải tất cả các mẫu, hãy đặt NUM_LOAD_FROM_EACH_SET
thành giá trị chuỗi all
.clean_text
được sử dụng để loại bỏ từng phiên âm văn bản của các ký tự được chỉ định bởi biểu thức chính quy được gán cho SPECIAL_CHARS
trong Bước 3.5 . Những ký tự này, bao gồm cả dấu câu, có thể bị loại bỏ vì chúng không cung cấp bất kỳ giá trị ngữ nghĩa nào khi đào tạo mô hình để tìm hiểu ánh xạ giữa các tính năng âm thanh và bản chép lời văn bản.create_vocab
tạo từ vựng từ các bản chép lại văn bản rõ ràng. Đơn giản, nó trích xuất tất cả các ký tự duy nhất từ bộ phiên âm văn bản đã được làm sạch. Bạn sẽ thấy ví dụ về từ vựng được tạo ở Bước 3.14 . Ô thứ bảy xác định các phương thức tiện ích sử dụng torchaudio
để tải và lấy mẫu lại dữ liệu âm thanh. Đặt ô thứ bảy thành:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
tải một tệp âm thanh được chỉ định và trả về một ma trận đa chiều torch.Tensor
của dữ liệu âm thanh cùng với tốc độ lấy mẫu của âm thanh. Tất cả các tệp âm thanh trong dữ liệu huấn luyện đều có tốc độ lấy mẫu là 48000
Hz. Tốc độ lấy mẫu "gốc" này được ghi lại bằng hằng số ORIG_SAMPLING_RATE
trong Bước 3.5 .resample
được sử dụng để giảm mẫu dữ liệu âm thanh từ tốc độ lấy mẫu từ 48000
xuống 16000
. wav2vec2 được huấn luyện trước trên âm thanh được lấy mẫu ở 16000
Hz. Theo đó, bất kỳ âm thanh nào được sử dụng để tinh chỉnh đều phải có cùng tốc độ lấy mẫu. Trong trường hợp này, mẫu âm thanh phải được giảm tần số lấy mẫu từ 48000
Hz xuống 16000
Hz. 16000
Hz được ghi lại bằng hằng số TGT_SAMPLING_RATE
ở Bước 3.5 .Ô thứ tám xác định các phương thức tiện ích xử lý dữ liệu âm thanh và phiên âm. Đặt ô thứ tám thành:
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
trả về giá trị đầu vào từ mẫu đào tạo được cung cấp.process_target_text
mã hóa mỗi bản phiên âm văn bản dưới dạng danh sách các nhãn - tức là danh sách các chỉ mục đề cập đến các ký tự trong từ vựng. Bạn sẽ thấy mã hóa mẫu ở Bước 3.15 .Ô thứ chín là ô phương thức tiện ích cuối cùng và chứa phương pháp tính Tỷ lệ lỗi từ giữa bản phiên âm tham chiếu và bản phiên âm dự đoán. Đặt ô thứ chín thành:
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
Ô thứ mười đọc các tệp chỉ mục dữ liệu huấn luyện cho bản ghi của người nói nam và bản ghi của người nói nữ bằng phương pháp read_index_file_data
được xác định trong Bước 3.6 . Đặt ô thứ mười thành:
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
Ô thứ mười một cắt bớt danh sách dữ liệu huấn luyện bằng phương pháp truncate_training_dataset
được xác định trong Bước 3.6 . Đặt ô thứ mười một thành:
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
được đặt ở Bước 3.5 xác định số lượng mẫu cần giữ lại từ mỗi tập dữ liệu. Hằng số được đặt thành 1600
trong hướng dẫn này cho tổng số 3200
mẫu.Ô thứ mười hai kết hợp các danh sách dữ liệu huấn luyện bị cắt bớt. Đặt ô thứ mười hai thành:
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
Ô thứ mười ba lặp lại từng mẫu dữ liệu huấn luyện và xóa văn bản phiên mã liên quan bằng phương pháp clean_text
được xác định trong Bước 3.6 . Đặt ô thứ mười ba thành:
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
Ô thứ mười bốn tạo từ vựng bằng cách sử dụng các bản phiên âm đã được làm sạch từ bước trước và phương thức create_vocab
được xác định ở Bước 3.6 . Đặt ô thứ mười bốn thành:
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
Từ vựng được lưu trữ dưới dạng từ điển với các ký tự là khóa và chỉ mục từ vựng là giá trị.
Bạn có thể in vocab_dict
để tạo ra kết quả sau:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
Ô thứ mười lăm thêm ký tự phân cách từ |
đến từ vựng. Đặt ô thứ mười lăm thành:
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
Ký tự phân cách từ được sử dụng khi mã hóa bản phiên âm văn bản dưới dạng danh sách nhãn. Cụ thể, nó được sử dụng để xác định phần cuối của một từ và nó được sử dụng khi khởi tạo lớp Wav2Vec2CTCTokenizer
, như sẽ thấy trong Bước 3.17 .
Ví dụ: danh sách sau đây mã hóa no te entiendo nada
bằng cách sử dụng từ vựng ở Bước 3.14 :
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
Ô thứ mười sáu chuyển từ vựng vào một tập tin. Đặt ô thứ mười sáu thành:
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
. Ô thứ mười bảy khởi tạo một phiên bản của Wav2Vec2CTCTokenizer
. Đặt ô thứ mười bảy thành:
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
Trình mã thông báo được sử dụng để mã hóa bản phiên âm văn bản và giải mã danh sách nhãn trở lại văn bản.
Lưu ý rằng tokenizer
được khởi tạo với [UNK]
được gán cho unk_token
và [PAD]
được gán cho pad_token
, với mã trước đây được sử dụng để biểu thị các mã thông báo không xác định trong bản phiên âm văn bản và mã thông báo sau được sử dụng để đệm phiên âm khi tạo các lô phiên âm có độ dài khác nhau. Hai giá trị này sẽ được thêm vào từ vựng bằng mã thông báo.
Việc khởi tạo mã thông báo trong bước này cũng sẽ thêm hai mã thông báo bổ sung vào từ vựng, đó là <s>
và /</s>
, được sử dụng để phân định lần lượt phần đầu và phần cuối của câu.
|
được gán cho word_delimiter_token
một cách rõ ràng trong bước này để phản ánh rằng ký hiệu ống sẽ được sử dụng để phân định ranh giới cuối từ theo cách chúng ta thêm ký tự vào từ vựng trong Bước 3.15 . |
ký hiệu là giá trị mặc định cho word_delimiter_token
. Vì vậy, nó không cần phải được thiết lập rõ ràng nhưng được thực hiện vì mục đích rõ ràng.
Tương tự như với word_delimiter_token
, một khoảng trắng được gán rõ ràng cho replace_word_delimiter_char
phản ánh rằng ký hiệu ống |
sẽ được sử dụng để thay thế các ký tự khoảng trống trong phiên âm văn bản. Khoảng trống là giá trị mặc định cho replace_word_delimiter_char
. Vì vậy, nó cũng không cần phải được thiết lập rõ ràng nhưng được thực hiện vì mục đích rõ ràng.
Bạn có thể in toàn bộ từ vựng của tokenizer bằng cách gọi phương thức get_vocab()
trên tokenizer
.
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
Ô thứ mười tám khởi tạo một phiên bản của Wav2Vec2FeatureExtractor
. Đặt ô thứ mười tám thành:
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
đều là các giá trị mặc định, ngoại trừ return_attention_mask
được mặc định là False
. Các giá trị mặc định được hiển thị/chuyển đi nhằm mục đích rõ ràng.feature_size
chỉ định kích thước kích thước của các tính năng đầu vào (tức là tính năng dữ liệu âm thanh). Giá trị mặc định của tham số này là 1
.sampling_rate
cho trình trích xuất tính năng biết tốc độ lấy mẫu mà tại đó dữ liệu âm thanh sẽ được số hóa. Như đã thảo luận ở Bước 3.7 , wav2vec2 được huấn luyện trước trên âm thanh được lấy mẫu ở 16000
Hz và do đó 16000
là giá trị mặc định cho tham số này.padding_value
chỉ định giá trị được sử dụng khi đệm dữ liệu âm thanh, theo yêu cầu khi phân nhóm các mẫu âm thanh có độ dài khác nhau. Giá trị mặc định là 0.0
.do_normalize
được sử dụng để chỉ định xem dữ liệu đầu vào có nên được chuyển đổi sang phân phối chuẩn chuẩn hay không. Giá trị mặc định là True
. Tài liệu lớp Wav2Vec2FeatureExtractor
lưu ý rằng "[chuẩn hóa] có thể giúp cải thiện đáng kể hiệu suất cho một số kiểu máy."return_attention_mask
chỉ định xem mặt nạ chú ý có được chuyển hay không. Giá trị được đặt thành True
cho trường hợp sử dụng này. Ô thứ mười chín khởi tạo một phiên bản của Wav2Vec2Processor
. Đặt ô thứ mười chín thành:
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
Lớp Wav2Vec2Processor
kết hợp tokenizer
và feature_extractor
từ Bước 3.17 và Bước 3.18 tương ứng vào một bộ xử lý duy nhất.
Lưu ý rằng cấu hình bộ xử lý có thể được lưu bằng cách gọi phương thức save_pretrained
trên phiên bản lớp Wav2Vec2Processor
.
processor.save_pretrained(OUTPUT_DIR_PATH)
Ô thứ 20 tải từng tệp âm thanh được chỉ định trong danh sách all_training_samples
. Đặt ô thứ hai mươi thành:
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
và được lưu trữ trong all_input_data
dưới dạng danh sách từ điển. Mỗi từ điển chứa dữ liệu âm thanh cho một mẫu cụ thể, cùng với bản phiên âm văn bản của âm thanh.read_audio_data
cũng trả về tốc độ lấy mẫu của dữ liệu âm thanh. Vì chúng ta biết rằng tốc độ lấy mẫu là 48000
Hz cho tất cả các tệp âm thanh trong trường hợp sử dụng này nên tốc độ lấy mẫu sẽ bị bỏ qua trong bước này.all_input_data
thành Pandas DataFrame Ô thứ 21 chuyển đổi danh sách all_input_data
thành Pandas DataFrame để giúp thao tác dữ liệu dễ dàng hơn. Đặt ô thứ 21 thành:
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
Ô thứ 22 sử dụng processor
được khởi tạo ở Bước 3.19 để trích xuất các tính năng từ từng mẫu dữ liệu âm thanh và mã hóa từng bản phiên âm văn bản dưới dạng danh sách các nhãn. Đặt ô thứ hai mươi hai thành:
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
Ô thứ hai mươi ba chia DataFrame all_input_data_df
thành các tập dữ liệu huấn luyện và đánh giá (xác thực) bằng cách sử dụng hằng số SPLIT_PCT
từ Bước 3.5 . Đặt ô thứ hai mươi ba thành:
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
là 0.10
trong hướng dẫn này, nghĩa là 10% tất cả dữ liệu đầu vào sẽ được giữ lại để đánh giá và 90% dữ liệu sẽ được sử dụng để đào tạo/tinh chỉnh.Dataset
Ô thứ 24 chuyển đổi DataFrames train_data_df
và valid_data_df
thành các đối tượng Dataset
. Đặt ô thứ 24 thành:
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Các đối tượng Dataset
được sử dụng bởi các phiên bản của lớp HuggingFace Trainer
, như bạn sẽ thấy trong Bước 3.30 .
Các đối tượng này chứa siêu dữ liệu về tập dữ liệu cũng như chính tập dữ liệu đó.
Bạn có thể in train_data
và valid_data
để xem siêu dữ liệu cho cả hai đối tượng Dataset
.
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
Ô thứ 25 khởi tạo mô hình XLS-R (0,3) đã được huấn luyện trước. Đặt ô thứ 25 thành:
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
from_pretrained
được gọi trên Wav2Vec2ForCTC
chỉ định rằng chúng ta muốn tải các trọng số đã được huấn luyện trước cho mô hình đã chỉ định.MODEL
được chỉ định ở Bước 3.5 và được đặt thành facebook/wav2vec2-xls-r-300m
phản ánh mô hình XLS-R (0.3).ctc_loss_reduction
chỉ định loại giảm để áp dụng cho đầu ra của hàm mất mát Phân loại thời gian kết nối ("CTC"). Suy hao CTC được sử dụng để tính toán suy hao giữa đầu vào liên tục, trong trường hợp này là dữ liệu âm thanh và chuỗi đích, trong trường hợp này là bản chép lại văn bản. Bằng cách đặt giá trị mean
, tổn thất đầu ra của một lô đầu vào sẽ được chia cho độ dài mục tiêu. Sau đó, giá trị trung bình của lô sẽ được tính toán và mức giảm được áp dụng cho các giá trị tổn thất.pad_token_id
chỉ định mã thông báo được sử dụng để đệm khi tạo khối. Nó được đặt thành id [PAD]
được đặt khi khởi tạo mã thông báo ở Bước 3.17 .vocab_size
xác định kích thước từ vựng của mô hình. Đó là kích thước từ vựng sau khi khởi tạo mã thông báo ở Bước 3.17 và phản ánh số lượng nút lớp đầu ra của phần chuyển tiếp của mạng.Ô thứ 26 đóng băng các trọng số đã được huấn luyện trước của bộ trích xuất đặc điểm. Đặt ô thứ hai mươi sáu thành:
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
Ô thứ 27 khởi tạo các đối số huấn luyện sẽ được chuyển đến phiên bản Trainer
. Đặt ô thứ hai mươi bảy thành:
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
chấp nhận hơn 100 tham số .save_safetensors
khi False
chỉ định rằng mô hình đã tinh chỉnh sẽ được lưu vào tệp pickle
thay vì sử dụng định dạng safetensors
.group_by_length
khi True
cho biết rằng các mẫu có độ dài xấp xỉ nhau sẽ được nhóm lại với nhau. Điều này giảm thiểu phần đệm và cải thiện hiệu quả đào tạo.per_device_train_batch_size
đặt số lượng mẫu cho mỗi đợt đào tạo nhỏ. Tham số này được đặt thành 18
thông qua hằng số TRAIN_BATCH_SIZE
được chỉ định ở Bước 3.5 . Điều này ngụ ý 160 bước mỗi kỷ nguyên.per_device_eval_batch_size
đặt số lượng mẫu cho mỗi lô nhỏ đánh giá (tạm giữ). Tham số này được đặt thành 10
thông qua hằng số EVAL_BATCH_SIZE
được chỉ định ở Bước 3.5 .num_train_epochs
đặt số lượng kỷ nguyên đào tạo. Tham số này được đặt thành 30
thông qua hằng số TRAIN_EPOCHS
được chỉ định ở Bước 3.5 . Điều này ngụ ý tổng số 4.800 bước trong quá trình đào tạo.gradient_checkpointing
khi True
giúp tiết kiệm bộ nhớ bằng cách kiểm tra các phép tính gradient nhưng dẫn đến tốc độ lùi lại chậm hơn.evaluation_strategy
khi được đặt thành steps
có nghĩa là việc đánh giá sẽ được thực hiện và ghi lại trong quá trình đào tạo ở khoảng thời gian được chỉ định bởi tham số eval_steps
.logging_strategy
khi được đặt thành steps
có nghĩa là số liệu thống kê về lần chạy huấn luyện sẽ được ghi lại theo khoảng thời gian được chỉ định bởi tham logging_steps
.save_strategy
khi được đặt thành steps
có nghĩa là điểm kiểm tra của mô hình đã tinh chỉnh sẽ được lưu trong khoảng thời gian được chỉ định bởi tham số save_steps
.eval_steps
đặt số bước giữa các lần đánh giá dữ liệu loại trừ. Tham số này được đặt thành 100
thông qua hằng số EVAL_STEPS
được gán ở Bước 3.5 .save_steps
đặt số bước sau đó điểm kiểm tra của mô hình đã tinh chỉnh sẽ được lưu. Tham số này được đặt thành 3200
thông qua hằng số SAVE_STEPS
được gán ở Bước 3.5 .logging_steps
đặt số bước giữa các nhật ký thống kê về lần chạy tập luyện. Tham số này được đặt thành 100
thông qua hằng số LOGGING_STEPS
được gán ở Bước 3.5 .learning_rate
đặt tốc độ học ban đầu. Tham số này được đặt thành 1e-4
thông qua hằng số LEARNING_RATE
được gán ở Bước 3.5 .warmup_steps
đặt số bước để tăng tốc độ học tập một cách tuyến tính từ 0 đến giá trị do learning_rate
đặt. Tham số này được đặt thành 800
thông qua hằng số WARMUP_STEPS
được gán ở Bước 3.5 .Ô thứ hai mươi tám xác định logic cho các chuỗi mục tiêu và đầu vào đệm động. Đặt ô thứ hai mươi tám thành:
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
. Các cặp này sẽ được khởi tạo trong giây lát ở Bước 3.30 . Do các chuỗi đầu vào và chuỗi nhãn có độ dài khác nhau trong mỗi lô nhỏ nên một số chuỗi phải được đệm để chúng có cùng độ dài.DataCollatorCTCWithPadding
tự động đệm dữ liệu theo lô nhỏ. Thông số padding
khi được đặt thành True
chỉ định rằng chuỗi tính năng đầu vào âm thanh ngắn hơn và chuỗi nhãn phải có cùng độ dài với chuỗi dài nhất trong một lô nhỏ.0.0
được đặt khi khởi chạy trình trích xuất tính năng ở Bước 3.18 .-100
để các nhãn này bị bỏ qua khi tính chỉ số WER.Ô thứ 29 khởi tạo một phiên bản của bộ đối chiếu dữ liệu được xác định ở bước trước. Đặt ô thứ hai mươi chín thành:
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
Ô thứ ba mươi khởi tạo một thể hiện của lớp Trainer
. Đặt ô thứ ba mươi thành:
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
được khởi tạo bằng:model
được huấn luyện trước được khởi tạo ở Bước 3.25 .Dataset
train_data
từ Bước 3.24 .Dataset
valid_data
từ Bước 3.24 .tokenizer
được gán cho processor.feature_extractor
và hoạt động với data_collator
để tự động đệm đầu vào vào đầu vào có độ dài tối đa của mỗi lô nhỏ. Ô thứ 31 gọi phương thức train
trên phiên bản lớp Trainer
để tinh chỉnh mô hình. Đặt ô thứ ba mươi mốt thành:
### CELL 31: Finetune the model ### trainer.train()
Ô thứ ba mươi hai là ô sổ tay cuối cùng. Nó lưu mô hình đã tinh chỉnh bằng cách gọi phương thức save_model
trên phiên bản Trainer
. Đặt ô thứ ba mươi giây thành:
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
Bây giờ tất cả các ô của sổ ghi chép đã được tạo xong, đã đến lúc bắt đầu tinh chỉnh.
Đặt Máy tính xách tay Kaggle chạy với bộ tăng tốc NVIDIA GPU P100 .
Cam kết sổ ghi chép trên Kaggle.
Giám sát dữ liệu lần chạy tập luyện bằng cách đăng nhập vào tài khoản WandB của bạn và định vị lần chạy liên quan.
Quá trình đào tạo trên 30 kỷ nguyên sẽ mất khoảng 5 giờ bằng cách sử dụng bộ tăng tốc NVIDIA GPU P100. WER trên dữ liệu loại trừ sẽ giảm xuống ~ 0,15 khi kết thúc khóa đào tạo. Đây không hẳn là một kết quả hiện đại nhưng mô hình đã được tinh chỉnh vẫn đủ hữu ích cho nhiều ứng dụng.
Mô hình đã tinh chỉnh sẽ được xuất ra thư mục Kaggle được chỉ định bởi hằng số OUTPUT_DIR_PATH
được chỉ định trong Bước 3.5 . Đầu ra của mô hình phải bao gồm các tệp sau:
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
Những tập tin này có thể được tải xuống cục bộ. Ngoài ra, bạn có thể tạo Mô hình Kaggle mới bằng cách sử dụng các tệp mô hình. Mô hình Kaggle sẽ được sử dụng cùng với hướng dẫn suy luận đồng hành để chạy suy luận trên mô hình đã được tinh chỉnh.
Chúc mừng bạn đã hoàn thiện wav2vec2 XLS-R! Hãy nhớ rằng bạn có thể sử dụng các bước chung này để tinh chỉnh mô hình trên các ngôn ngữ khác mà bạn mong muốn. Việc chạy suy luận trên mô hình tinh chỉnh được tạo trong hướng dẫn này khá đơn giản. Các bước suy luận sẽ được trình bày trong hướng dẫn đồng hành riêng cho hướng dẫn này. Vui lòng tìm kiếm tên người dùng HackerNoon của tôi để tìm hướng dẫn đồng hành.