paint-brush
Làm việc với Wav2vec2 Phần 1: Tinh chỉnh XLS-R để nhận dạng giọng nói tự độngtừ tác giả@pictureinthenoise
2,245 lượt đọc
2,245 lượt đọc

Làm việc với Wav2vec2 Phần 1: Tinh chỉnh XLS-R để nhận dạng giọng nói tự động

từ tác giả Picture in the Noise29m2024/05/04
Read on Terminal Reader

dài quá đọc không nổi

Hướng dẫn này giải thích các bước để hoàn thiện mô hình wav2vec2 XLS-R của Meta AI để nhận dạng giọng nói tự động ("ASR"). Hướng dẫn này bao gồm hướng dẫn từng bước về cách xây dựng Sổ tay Kaggle có thể được sử dụng để tinh chỉnh mô hình. Mô hình được đào tạo trên bộ dữ liệu tiếng Tây Ban Nha của Chile.
featured image - Làm việc với Wav2vec2 Phần 1: Tinh chỉnh XLS-R để nhận dạng giọng nói tự động
Picture in the Noise HackerNoon profile picture
0-item
1-item

Giới thiệu

Meta AI đã giới thiệu wav2vec2 XLS-R ("XLS-R") vào cuối năm 2021. XLS-R là mô hình máy học ("ML") để học cách biểu đạt giọng nói đa ngôn ngữ; và nó đã được đào tạo trên hơn 400.000 giờ âm thanh giọng nói có sẵn công khai trên 128 ngôn ngữ. Sau khi phát hành, mô hình này thể hiện một bước nhảy vọt so với mô hình đa ngôn ngữ XLSR-53 của Meta AI, mô hình này đã được đào tạo về khoảng 50.000 giờ âm thanh lời nói trên 53 ngôn ngữ.


Hướng dẫn này giải thích các bước để tinh chỉnh XLS-R để nhận dạng giọng nói tự động ("ASR") bằng Sổ tay Kaggle . Mô hình này sẽ được tinh chỉnh bằng tiếng Tây Ban Nha Chile, nhưng bạn có thể làm theo các bước chung để tinh chỉnh XLS-R trên các ngôn ngữ khác nhau mà bạn mong muốn.


Việc chạy suy luận trên mô hình tinh chỉnh sẽ được mô tả trong phần hướng dẫn đi kèm, khiến hướng dẫn này trở thành phần đầu tiên trong hai phần. Tôi quyết định tạo một hướng dẫn riêng dành riêng cho suy luận vì hướng dẫn tinh chỉnh này hơi dài.


Giả sử bạn đã có nền tảng ML hiện có và hiểu các khái niệm ASR cơ bản. Người mới bắt đầu có thể gặp khó khăn khi theo dõi/hiểu các bước xây dựng.

Một chút thông tin cơ bản về XLS-R

Mô hình wav2vec2 ban đầu được giới thiệu vào năm 2020 đã được đào tạo trước trên 960 giờ âm thanh giọng nói của tập dữ liệu Librispeech và ~53.200 giờ âm thanh giọng nói của tập dữ liệu LibriVox . Sau khi phát hành, có hai kích cỡ mô hình: mô hình BASE với 95 triệu tham số và mô hình LARGE với 317 triệu tham số.


Mặt khác, XLS-R đã được huấn luyện trước về âm thanh lời nói đa ngôn ngữ từ 5 bộ dữ liệu:


  • VoxPopuli : Tổng cộng có ~372.000 giờ âm thanh bài phát biểu trên 23 ngôn ngữ phát biểu nghị viện Châu Âu từ quốc hội Châu Âu.
  • Thư viện đa ngôn ngữ : Tổng cộng ~50.000 giờ âm thanh giọng nói trên tám ngôn ngữ Châu Âu, với phần lớn (~44.000 giờ) dữ liệu âm thanh bằng tiếng Anh.
  • CommonVoice : Tổng cộng ~7.000 giờ âm thanh giọng nói trên 60 ngôn ngữ.
  • VoxLingua107 : Tổng cộng ~6.600 giờ âm thanh giọng nói trên 107 ngôn ngữ dựa trên nội dung YouTube.
  • BABEL : Tổng cộng ~1.100 giờ âm thanh giọng nói trên 17 ngôn ngữ Châu Phi và Châu Á dựa trên lời nói đàm thoại qua điện thoại.


Có 3 mẫu XLS-R: XLS-R (0,3B) với 300 triệu thông số, XLS-R (1B) với 1 tỷ thông số và XLS-R (2B) với 2 tỷ thông số. Hướng dẫn này sẽ sử dụng mẫu XLS-R (0,3B).

Tiếp cận

Có một số bài viết hay về cách tinh chỉnh các mô hình wav2vev2 , có lẽ mô hình này là một loại "tiêu chuẩn vàng". Tất nhiên, cách tiếp cận chung ở đây bắt chước những gì bạn sẽ tìm thấy trong các hướng dẫn khác. Bạn sẽ:


  • Tải tập dữ liệu huấn luyện gồm dữ liệu âm thanh và bản ghi văn bản liên quan.
  • Tạo từ vựng từ bản phiên âm văn bản trong tập dữ liệu.
  • Khởi tạo bộ xử lý wav2vec2 sẽ trích xuất các tính năng từ dữ liệu đầu vào cũng như chuyển đổi bản phiên âm văn bản thành chuỗi nhãn.
  • Tinh chỉnh wav2vec2 XLS-R trên dữ liệu đầu vào đã xử lý.


Tuy nhiên, có ba điểm khác biệt chính giữa hướng dẫn này và các hướng dẫn khác:


  1. Hướng dẫn không cung cấp nhiều cuộc thảo luận "nội tuyến" về các khái niệm ML và ASR có liên quan.
    • Mặc dù mỗi phần phụ trên các ô sổ tay riêng lẻ sẽ bao gồm thông tin chi tiết về việc sử dụng/mục đích của ô cụ thể, nhưng giả định rằng bạn có nền tảng ML hiện có và bạn hiểu các khái niệm ASR cơ bản.
  2. Sổ tay Kaggle mà bạn sẽ xây dựng sẽ sắp xếp các phương thức tiện ích trong các ô cấp cao nhất.
    • Trong khi nhiều sổ ghi chép tinh chỉnh có xu hướng có kiểu bố cục kiểu "dòng ý thức", tôi đã quyết định sắp xếp tất cả các phương pháp hữu ích lại với nhau. Nếu mới làm quen với wav2vec2, bạn có thể thấy cách tiếp cận này khó hiểu. Tuy nhiên, để nhắc lại, tôi cố gắng hết sức để giải thích rõ ràng mục đích của từng ô trong phần phụ dành riêng cho mỗi ô. Nếu bạn mới tìm hiểu về wav2vec2, bạn có thể được hưởng lợi từ việc xem nhanh bài viết HackerNoon wav2vec2 của tôi về Nhận dạng giọng nói tự động bằng tiếng Anh đơn giản .
  3. Hướng dẫn này chỉ mô tả các bước để tinh chỉnh.
    • Như đã đề cập trong phần Giới thiệu , tôi đã chọn tạo một hướng dẫn đồng hành riêng về cách chạy suy luận trên mô hình XLS-R đã được tinh chỉnh mà bạn sẽ tạo. Điều này được thực hiện để tránh hướng dẫn này trở nên quá dài.

Điều kiện tiên quyết và trước khi bạn bắt đầu

Để hoàn thành hướng dẫn, bạn sẽ cần phải có:


  • Một tài khoản Kaggle hiện có. Nếu chưa có tài khoản Kaggle, bạn cần tạo một tài khoản.
  • Tài khoản Trọng số và Xu hướng hiện có ("WandB") . Nếu hiện tại bạn chưa có tài khoản Trọng lượng và Xu hướng, bạn cần tạo một tài khoản.
  • Khóa API WandB. Nếu bạn không có khóa API WandB, hãy làm theo các bước tại đây .
  • Kiến thức trung cấp về Python.
  • Kiến thức trung cấp về làm việc với Kaggle Notebooks.
  • Kiến thức trung cấp về các khái niệm ML.
  • Kiến thức cơ bản về các khái niệm ASR.


Trước khi bắt đầu xây dựng sổ ghi chép, bạn có thể xem lại hai phần phụ ngay bên dưới. Họ mô tả:


  1. Tập dữ liệu huấn luyện.
  2. Số liệu Tỷ lệ Lỗi Từ ("WER") được sử dụng trong quá trình đào tạo.

Tập dữ liệu đào tạo

Như đã đề cập trong phần Giới thiệu , mẫu XLS-R sẽ được tinh chỉnh bằng tiếng Tây Ban Nha ở Chile. Tập dữ liệu cụ thể là Tập dữ liệu lời nói tiếng Tây Ban Nha ở Chile được phát triển bởi Guevara-Rukoz et al. Nó có sẵn để tải xuống trên OpenSLR . Bộ dữ liệu bao gồm hai bộ dữ liệu phụ: (1) 2.636 bản ghi âm của những người nói tiếng Chile là nam và (2) 1.738 bản ghi âm của những người nói tiếng Chile là nữ.


Mỗi tập dữ liệu con bao gồm một tệp chỉ mục line_index.tsv . Mỗi dòng của mỗi tệp chỉ mục chứa một cặp tên tệp âm thanh và bản ghi âm của tệp được liên kết, ví dụ:


 clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente


Tôi đã tải Bộ dữ liệu lời nói tiếng Tây Ban Nha tiếng Chile lên Kaggle để thuận tiện. Có một tập dữ liệu Kaggle cho các bản ghi âm của những người nói tiếng Chile là nam giới và một bộ dữ liệu Kaggle cho các bản ghi âm của những người nói tiếng Chile là nữ . Các bộ dữ liệu Kaggle này sẽ được thêm vào Sổ tay Kaggle mà bạn sẽ xây dựng theo các bước trong hướng dẫn này.

Tỷ lệ lỗi từ (WER)

WER là một số liệu có thể được sử dụng để đo lường hiệu suất của các mô hình nhận dạng giọng nói tự động. WER cung cấp một cơ chế để đo lường mức độ gần gũi của dự đoán văn bản với tham chiếu văn bản. WER thực hiện điều này bằng cách ghi lại 3 loại lỗi:


  • thay thế ( S ): Lỗi thay thế được ghi lại khi dự đoán chứa một từ khác với từ tương tự trong tham chiếu. Ví dụ: điều này xảy ra khi dự đoán viết sai chính tả một từ trong tài liệu tham khảo.

  • xóa ( D ): Lỗi xóa được ghi lại khi dự đoán chứa một từ không có trong tham chiếu.

  • phần chèn ( I ): Lỗi chèn được ghi lại khi dự đoán không chứa từ nào có trong tham chiếu.


Rõ ràng, WER hoạt động ở cấp độ từ. Công thức cho số liệu WER như sau:


 WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference


Một ví dụ WER đơn giản bằng tiếng Tây Ban Nha như sau:


 prediction: "Él está saliendo." reference: "Él está saltando."


Một bảng giúp hình dung các lỗi trong dự đoán:

CHỮ

TỪ 1

TỪ 2

TỪ 3

sự dự đoán

Él

está

saliendo

thẩm quyền giải quyết

Él

está

muối


Chính xác

Chính xác

thay thế

Dự đoán có 1 lỗi thay thế, 0 lỗi xóa và 0 lỗi chèn. Vì vậy, WER cho ví dụ này là:


 WER = 1 + 0 + 0 / 3 = 1/3 = 0.33


Rõ ràng là Tỷ lệ lỗi từ không nhất thiết cho chúng ta biết những lỗi cụ thể nào đang tồn tại. Trong ví dụ trên, WER xác định rằng WORD 3 có lỗi trong văn bản được dự đoán, nhưng nó không cho chúng ta biết rằng các ký tự ie sai trong dự đoán. Các số liệu khác, chẳng hạn như Tỷ lệ lỗi ký tự ("CER"), có thể được sử dụng để phân tích lỗi chính xác hơn.

Xây dựng sổ tay tinh chỉnh

Bây giờ bạn đã sẵn sàng để bắt đầu xây dựng sổ ghi chép tinh chỉnh.


  • Bước 1Bước 2 hướng dẫn bạn thiết lập môi trường Kaggle Notebook.
  • Bước 3 hướng dẫn bạn cách xây dựng sổ ghi chép. Nó chứa 32 bước phụ đại diện cho 32 ô của sổ ghi chép tinh chỉnh.
  • Bước 4 hướng dẫn bạn cách chạy sổ ghi chép, theo dõi quá trình đào tạo và lưu mô hình.

Bước 1 - Tìm nạp khóa API WandB của bạn

Sổ tay Kaggle của bạn phải được định cấu hình để gửi dữ liệu chạy đào tạo tới WandB bằng khóa API WandB của bạn. Để làm được điều đó, bạn cần sao chép nó.


  1. Đăng nhập vào WandB tại www.wandb.com .
  2. Điều hướng đến www.wandb.ai/authorize .
  3. Sao chép khóa API của bạn để sử dụng trong bước tiếp theo.

Bước 2 - Thiết lập môi trường Kaggle của bạn

Bước 2.1 - Tạo sổ tay Kaggle mới


  1. Đăng nhập vào Kaggle.
  2. Tạo Sổ tay Kaggle mới.
  3. Tất nhiên, tên sổ ghi chép có thể được thay đổi theo ý muốn. Hướng dẫn này sử dụng tên sổ ghi chép xls-r-300m-chilean-spanish-asr .

Bước 2.2 - Đặt khóa API WandB của bạn

Bí mật Kaggle sẽ được sử dụng để lưu trữ khóa API WandB của bạn một cách an toàn.


  1. Nhấp vào Tiện ích bổ sung trên menu chính của Kaggle Notebook.
  2. Chọn Bí mật từ menu bật lên.
  3. Nhập nhãn WANDB_API_KEY vào trường Nhãn và nhập khóa API WandB của bạn cho giá trị.
  4. Đảm bảo rằng hộp kiểm Đã đính kèm ở bên trái của trường nhãn WANDB_API_KEY được chọn.
  5. Nhấp vào Xong .

Bước 2.3 - Thêm bộ dữ liệu đào tạo

Tập dữ liệu giọng nói tiếng Tây Ban Nha ở Chile đã được tải lên Kaggle dưới dạng 2 tập dữ liệu riêng biệt:


Thêm cả hai bộ dữ liệu này vào Sổ tay Kaggle của bạn.

Bước 3 - Xây dựng sổ tay tinh chỉnh

32 bước phụ sau đây sẽ xây dựng từng ô trong số 32 ô của sổ ghi chép tinh chỉnh theo thứ tự.

Bước 3.1 - Ô 1: Cài đặt gói

Ô đầu tiên của sổ ghi chép tinh chỉnh sẽ cài đặt các phần phụ thuộc. Đặt ô đầu tiên thành:


 ### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer


  • Dòng đầu tiên nâng cấp gói torchaudio lên phiên bản mới nhất. torchaudio sẽ được sử dụng để tải các tệp âm thanh và lấy mẫu lại dữ liệu âm thanh.
  • Dòng thứ hai cài đặt gói jiwer được yêu cầu để sử dụng phương thức load_metric của thư viện HuggingFace Datasets được sử dụng sau này.

Bước 3.2 - Ô 2: Nhập gói Python

Việc nhập ô thứ hai yêu cầu các gói Python. Đặt ô thứ hai thành:


 ### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer


  • Có lẽ bạn đã quen thuộc với hầu hết các gói này. Việc sử dụng chúng trong sổ ghi chép sẽ được giải thích khi các ô tiếp theo được xây dựng.
  • Điều đáng nói là thư viện transformers HuggingFace và các lớp Wav2Vec2* liên quan cung cấp nền tảng cho chức năng được sử dụng để tinh chỉnh.

Bước 3.3 - Ô 3: Đang tải số liệu WER

Ô thứ ba nhập số liệu đánh giá HuggingFace WER. Đặt ô thứ ba thành:


 ### CELL 3: Load WER metric ### wer_metric = load_metric("wer")


  • Như đã đề cập trước đó, WER sẽ được sử dụng để đo lường hiệu suất của mô hình trên dữ liệu đánh giá/nắm giữ.

Bước 3.4 - Ô 4: Đăng nhập vào WandB

Ô thứ tư lấy bí mật WANDB_API_KEY của bạn đã được đặt ở Bước 2.2 . Đặt ô thứ tư thành:


 ### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)


  • Khóa API được sử dụng để định cấu hình Kaggle Notebook để dữ liệu chạy đào tạo được gửi đến WandB.

Bước 3.5 - Ô 5: Đặt hằng số

Ô thứ năm đặt các hằng số sẽ được sử dụng trong toàn bộ sổ ghi chép. Đặt ô thứ năm thành:


 ### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800


  • Sổ ghi chép không hiển thị mọi hằng số có thể tưởng tượng được trong ô này. Một số giá trị có thể được biểu thị bằng hằng số đã được giữ nguyên dòng.
  • Việc sử dụng nhiều hằng số ở trên là hiển nhiên. Đối với những trường hợp không, việc sử dụng chúng sẽ được giải thích trong các bước phụ sau.

Bước 3.6 - Ô 6: Các phương pháp tiện ích để đọc tệp chỉ mục, làm sạch văn bản và tạo từ vựng

Ô thứ sáu xác định các phương thức tiện ích để đọc các tệp chỉ mục tập dữ liệu (xem phần phụ Tập dữ liệu huấn luyện ở trên), cũng như để làm sạch văn bản phiên âm và tạo từ vựng. Đặt ô thứ sáu thành:


 ### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list


  • Phương thức read_index_file_data đọc tệp chỉ mục tập dữ liệu line_index.tsv và tạo danh sách các danh sách có tên tệp âm thanh và dữ liệu phiên âm, ví dụ:


 [ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]


  • Phương thức truncate_training_dataset cắt bớt dữ liệu tệp chỉ mục danh sách bằng cách sử dụng hằng NUM_LOAD_FROM_EACH_SET được đặt ở Bước 3.5 . Cụ thể, hằng số NUM_LOAD_FROM_EACH_SET được sử dụng để chỉ định số lượng mẫu âm thanh cần tải từ mỗi tập dữ liệu. Vì mục đích của hướng dẫn này, con số được đặt ở 1600 , nghĩa là cuối cùng tổng cộng 3200 mẫu âm thanh sẽ được tải. Để tải tất cả các mẫu, hãy đặt NUM_LOAD_FROM_EACH_SET thành giá trị chuỗi all .
  • Phương thức clean_text được sử dụng để loại bỏ từng phiên âm văn bản của các ký tự được chỉ định bởi biểu thức chính quy được gán cho SPECIAL_CHARS trong Bước 3.5 . Những ký tự này, bao gồm cả dấu câu, có thể bị loại bỏ vì chúng không cung cấp bất kỳ giá trị ngữ nghĩa nào khi đào tạo mô hình để tìm hiểu ánh xạ giữa các tính năng âm thanh và bản chép lời văn bản.
  • Phương thức create_vocab tạo từ vựng từ các bản chép lại văn bản rõ ràng. Đơn giản, nó trích xuất tất cả các ký tự duy nhất từ bộ phiên âm văn bản đã được làm sạch. Bạn sẽ thấy ví dụ về từ vựng được tạo ở Bước 3.14 .

Bước 3.7 - CELL 7: Các phương pháp tiện ích để tải và lấy mẫu lại dữ liệu âm thanh

Ô thứ bảy xác định các phương thức tiện ích sử dụng torchaudio để tải và lấy mẫu lại dữ liệu âm thanh. Đặt ô thứ bảy thành:


 ### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]


  • Phương thức read_audio_data tải một tệp âm thanh được chỉ định và trả về một ma trận đa chiều torch.Tensor của dữ liệu âm thanh cùng với tốc độ lấy mẫu của âm thanh. Tất cả các tệp âm thanh trong dữ liệu huấn luyện đều có tốc độ lấy mẫu là 48000 Hz. Tốc độ lấy mẫu "gốc" này được ghi lại bằng hằng số ORIG_SAMPLING_RATE trong Bước 3.5 .
  • Phương pháp resample được sử dụng để giảm mẫu dữ liệu âm thanh từ tốc độ lấy mẫu từ 48000 xuống 16000 . wav2vec2 được huấn luyện trước trên âm thanh được lấy mẫu ở 16000 Hz. Theo đó, bất kỳ âm thanh nào được sử dụng để tinh chỉnh đều phải có cùng tốc độ lấy mẫu. Trong trường hợp này, mẫu âm thanh phải được giảm tần số lấy mẫu từ 48000 Hz xuống 16000 Hz. 16000 Hz được ghi lại bằng hằng số TGT_SAMPLING_RATEBước 3.5 .

Bước 3.8 - Ô 8: Các phương pháp hữu ích để chuẩn bị dữ liệu cho đào tạo

Ô thứ tám xác định các phương thức tiện ích xử lý dữ liệu âm thanh và phiên âm. Đặt ô thứ tám thành:


 ### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding


  • Phương thức process_speech_audio trả về giá trị đầu vào từ mẫu đào tạo được cung cấp.
  • Phương thức process_target_text mã hóa mỗi bản phiên âm văn bản dưới dạng danh sách các nhãn - tức là danh sách các chỉ mục đề cập đến các ký tự trong từ vựng. Bạn sẽ thấy mã hóa mẫu ở Bước 3.15 .

Bước 3.9 - CELL 9: Phương pháp tiện ích để tính tỷ lệ lỗi từ

Ô thứ chín là ô phương thức tiện ích cuối cùng và chứa phương pháp tính Tỷ lệ lỗi từ giữa bản phiên âm tham chiếu và bản phiên âm dự đoán. Đặt ô thứ chín thành:


 ### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}

Bước 3.10 - Ô 10: Đọc dữ liệu huấn luyện

Ô thứ mười đọc các tệp chỉ mục dữ liệu huấn luyện cho bản ghi của người nói nam và bản ghi của người nói nữ bằng phương pháp read_index_file_data được xác định trong Bước 3.6 . Đặt ô thứ mười thành:


 ### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")


  • Như đã thấy, dữ liệu đào tạo được quản lý theo hai danh sách dành riêng cho giới tính tại thời điểm này. Dữ liệu sẽ được kết hợp ở Bước 3.12 sau khi cắt bớt.

Bước 3.11 - Ô 11: Cắt bớt dữ liệu huấn luyện

Ô thứ mười một cắt bớt danh sách dữ liệu huấn luyện bằng phương pháp truncate_training_dataset được xác định trong Bước 3.6 . Đặt ô thứ mười một thành:


 ### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)


  • Xin nhắc lại, hằng số NUM_LOAD_FROM_EACH_SET được đặt ở Bước 3.5 xác định số lượng mẫu cần giữ lại từ mỗi tập dữ liệu. Hằng số được đặt thành 1600 trong hướng dẫn này cho tổng số 3200 mẫu.

Bước 3.12 - Ô 12: Kết hợp dữ liệu mẫu huấn luyện

Ô thứ mười hai kết hợp các danh sách dữ liệu huấn luyện bị cắt bớt. Đặt ô thứ mười hai thành:


 ### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl

Bước 3.13 - Ô 13: Kiểm tra phiên mã sạch

Ô thứ mười ba lặp lại từng mẫu dữ liệu huấn luyện và xóa văn bản phiên mã liên quan bằng phương pháp clean_text được xác định trong Bước 3.6 . Đặt ô thứ mười ba thành:


 for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])

Bước 3.14 - Ô 14: Tạo từ vựng

Ô thứ mười bốn tạo từ vựng bằng cách sử dụng các bản phiên âm đã được làm sạch từ bước trước và phương thức create_vocab được xác định ở Bước 3.6 . Đặt ô thứ mười bốn thành:


 ### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}


  • Từ vựng được lưu trữ dưới dạng từ điển với các ký tự là khóa và chỉ mục từ vựng là giá trị.

  • Bạn có thể in vocab_dict để tạo ra kết quả sau:


 {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}

Bước 3.15 - Ô 15: Thêm dấu phân cách từ vào từ vựng

Ô thứ mười lăm thêm ký tự phân cách từ | đến từ vựng. Đặt ô thứ mười lăm thành:


 ### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)


  • Ký tự phân cách từ được sử dụng khi mã hóa bản phiên âm văn bản dưới dạng danh sách nhãn. Cụ thể, nó được sử dụng để xác định phần cuối của một từ và nó được sử dụng khi khởi tạo lớp Wav2Vec2CTCTokenizer , như sẽ thấy trong Bước 3.17 .

  • Ví dụ: danh sách sau đây mã hóa no te entiendo nada bằng cách sử dụng từ vựng ở Bước 3.14 :


 # Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}


  • Một câu hỏi có thể nảy sinh một cách tự nhiên là: "Tại sao cần xác định ký tự phân cách từ?" Ví dụ: phần cuối của các từ trong văn bản tiếng Anh và tiếng Tây Ban Nha được đánh dấu bằng khoảng trắng nên việc sử dụng ký tự khoảng trắng làm dấu phân cách từ là một vấn đề đơn giản. Hãy nhớ rằng tiếng Anh và tiếng Tây Ban Nha chỉ là hai ngôn ngữ trong số hàng nghìn ngôn ngữ; và không phải tất cả các ngôn ngữ viết đều sử dụng dấu cách để đánh dấu ranh giới từ.

Bước 3.16 - Ô 16: Xuất từ vựng

Ô thứ mười sáu chuyển từ vựng vào một tập tin. Đặt ô thứ mười sáu thành:


 ### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)


  • Tệp từ vựng sẽ được sử dụng trong bước tiếp theo, Bước 3.17 , để khởi tạo lớp Wav2Vec2CTCTokenizer .

Bước 3.17 - CELL 17: Khởi tạo Tokenizer

Ô thứ mười bảy khởi tạo một phiên bản của Wav2Vec2CTCTokenizer . Đặt ô thứ mười bảy thành:


 ### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )


  • Trình mã thông báo được sử dụng để mã hóa bản phiên âm văn bản và giải mã danh sách nhãn trở lại văn bản.

  • Lưu ý rằng tokenizer được khởi tạo với [UNK] được gán cho unk_token[PAD] được gán cho pad_token , với mã trước đây được sử dụng để biểu thị các mã thông báo không xác định trong bản phiên âm văn bản và mã thông báo sau được sử dụng để đệm phiên âm khi tạo các lô phiên âm có độ dài khác nhau. Hai giá trị này sẽ được thêm vào từ vựng bằng mã thông báo.

  • Việc khởi tạo mã thông báo trong bước này cũng sẽ thêm hai mã thông báo bổ sung vào từ vựng, đó là <s>/</s> , được sử dụng để phân định lần lượt phần đầu và phần cuối của câu.

  • | được gán cho word_delimiter_token một cách rõ ràng trong bước này để phản ánh rằng ký hiệu ống sẽ được sử dụng để phân định ranh giới cuối từ theo cách chúng ta thêm ký tự vào từ vựng trong Bước 3.15 . | ký hiệu là giá trị mặc định cho word_delimiter_token . Vì vậy, nó không cần phải được thiết lập rõ ràng nhưng được thực hiện vì mục đích rõ ràng.

  • Tương tự như với word_delimiter_token , một khoảng trắng được gán rõ ràng cho replace_word_delimiter_char phản ánh rằng ký hiệu ống | sẽ được sử dụng để thay thế các ký tự khoảng trống trong phiên âm văn bản. Khoảng trống là giá trị mặc định cho replace_word_delimiter_char . Vì vậy, nó cũng không cần phải được thiết lập rõ ràng nhưng được thực hiện vì mục đích rõ ràng.

  • Bạn có thể in toàn bộ từ vựng của tokenizer bằng cách gọi phương thức get_vocab() trên tokenizer .


 vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}

Bước 3.18 - Ô 18: Khởi tạo Trình trích xuất tính năng

Ô thứ mười tám khởi tạo một phiên bản của Wav2Vec2FeatureExtractor . Đặt ô thứ mười tám thành:


 ### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )


  • Trình trích xuất tính năng được sử dụng để trích xuất các tính năng từ dữ liệu đầu vào, tất nhiên là dữ liệu âm thanh trong trường hợp sử dụng này. Bạn sẽ tải dữ liệu âm thanh cho từng mẫu dữ liệu huấn luyện ở Bước 3.20 .
  • Các giá trị tham số được truyền tới trình khởi tạo Wav2Vec2FeatureExtractor đều là các giá trị mặc định, ngoại trừ return_attention_mask được mặc định là False . Các giá trị mặc định được hiển thị/chuyển đi nhằm mục đích rõ ràng.
  • Tham số feature_size chỉ định kích thước kích thước của các tính năng đầu vào (tức là tính năng dữ liệu âm thanh). Giá trị mặc định của tham số này là 1 .
  • sampling_rate cho trình trích xuất tính năng biết tốc độ lấy mẫu mà tại đó dữ liệu âm thanh sẽ được số hóa. Như đã thảo luận ở Bước 3.7 , wav2vec2 được huấn luyện trước trên âm thanh được lấy mẫu ở 16000 Hz và do đó 16000 là giá trị mặc định cho tham số này.
  • Tham số padding_value chỉ định giá trị được sử dụng khi đệm dữ liệu âm thanh, theo yêu cầu khi phân nhóm các mẫu âm thanh có độ dài khác nhau. Giá trị mặc định là 0.0 .
  • do_normalize được sử dụng để chỉ định xem dữ liệu đầu vào có nên được chuyển đổi sang phân phối chuẩn chuẩn hay không. Giá trị mặc định là True . Tài liệu lớp Wav2Vec2FeatureExtractor lưu ý rằng "[chuẩn hóa] có thể giúp cải thiện đáng kể hiệu suất cho một số kiểu máy."
  • Các tham số return_attention_mask chỉ định xem mặt nạ chú ý có được chuyển hay không. Giá trị được đặt thành True cho trường hợp sử dụng này.

Bước 3.19 - CELL 19: Khởi tạo bộ xử lý

Ô thứ mười chín khởi tạo một phiên bản của Wav2Vec2Processor . Đặt ô thứ mười chín thành:


 ### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)


  • Lớp Wav2Vec2Processor kết hợp tokenizerfeature_extractor từ Bước 3.17Bước 3.18 tương ứng vào một bộ xử lý duy nhất.

  • Lưu ý rằng cấu hình bộ xử lý có thể được lưu bằng cách gọi phương thức save_pretrained trên phiên bản lớp Wav2Vec2Processor .


 processor.save_pretrained(OUTPUT_DIR_PATH)

Bước 3.20 - CELL 20: Đang tải dữ liệu âm thanh

Ô thứ 20 tải từng tệp âm thanh được chỉ định trong danh sách all_training_samples . Đặt ô thứ hai mươi thành:


 ### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })


  • Dữ liệu âm thanh được trả về dưới dạng torch.Tensor và được lưu trữ trong all_input_data dưới dạng danh sách từ điển. Mỗi từ điển chứa dữ liệu âm thanh cho một mẫu cụ thể, cùng với bản phiên âm văn bản của âm thanh.
  • Lưu ý rằng phương thức read_audio_data cũng trả về tốc độ lấy mẫu của dữ liệu âm thanh. Vì chúng ta biết rằng tốc độ lấy mẫu là 48000 Hz cho tất cả các tệp âm thanh trong trường hợp sử dụng này nên tốc độ lấy mẫu sẽ bị bỏ qua trong bước này.

Bước 3.21 - CELL 21: Chuyển đổi all_input_data thành Pandas DataFrame

Ô thứ 21 chuyển đổi danh sách all_input_data thành Pandas DataFrame để giúp thao tác dữ liệu dễ dàng hơn. Đặt ô thứ 21 thành:


 ### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)

Bước 3.22 - CELL 22: Xử lý dữ liệu âm thanh và phiên âm văn bản

Ô thứ 22 sử dụng processor được khởi tạo ở Bước 3.19 để trích xuất các tính năng từ từng mẫu dữ liệu âm thanh và mã hóa từng bản phiên âm văn bản dưới dạng danh sách các nhãn. Đặt ô thứ hai mươi hai thành:


 ### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))

Bước 3.23 - CELL 23: Tách dữ liệu đầu vào thành tập dữ liệu huấn luyện và xác thực

Ô thứ hai mươi ba chia DataFrame all_input_data_df thành các tập dữ liệu huấn luyện và đánh giá (xác thực) bằng cách sử dụng hằng số SPLIT_PCT từ Bước 3.5 . Đặt ô thứ hai mươi ba thành:


 ### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]


  • Giá trị SPLIT_PCT0.10 trong hướng dẫn này, nghĩa là 10% tất cả dữ liệu đầu vào sẽ được giữ lại để đánh giá và 90% dữ liệu sẽ được sử dụng để đào tạo/tinh chỉnh.
  • Vì có tổng cộng 3.200 mẫu huấn luyện nên 320 mẫu sẽ được sử dụng để đánh giá và 2.880 mẫu còn lại được sử dụng để tinh chỉnh mô hình.

Bước 3.24 - CELL 24: Chuyển đổi tập dữ liệu huấn luyện và xác thực thành đối tượng Dataset

Ô thứ 24 chuyển đổi DataFrames train_data_dfvalid_data_df thành các đối tượng Dataset . Đặt ô thứ 24 thành:


 ### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)


  • Các đối tượng Dataset được sử dụng bởi các phiên bản của lớp HuggingFace Trainer , như bạn sẽ thấy trong Bước 3.30 .

  • Các đối tượng này chứa siêu dữ liệu về tập dữ liệu cũng như chính tập dữ liệu đó.

  • Bạn có thể in train_datavalid_data để xem siêu dữ liệu cho cả hai đối tượng Dataset .


 print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })

Bước 3.25 - CELL 25: Khởi tạo mô hình tiền huấn luyện

Ô thứ 25 khởi tạo mô hình XLS-R (0,3) đã được huấn luyện trước. Đặt ô thứ 25 thành:


 ### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )


  • Phương thức from_pretrained được gọi trên Wav2Vec2ForCTC chỉ định rằng chúng ta muốn tải các trọng số đã được huấn luyện trước cho mô hình đã chỉ định.
  • Hằng số MODEL được chỉ định ở Bước 3.5 và được đặt thành facebook/wav2vec2-xls-r-300m phản ánh mô hình XLS-R (0.3).
  • Tham số ctc_loss_reduction chỉ định loại giảm để áp dụng cho đầu ra của hàm mất mát Phân loại thời gian kết nối ("CTC"). Suy hao CTC được sử dụng để tính toán suy hao giữa đầu vào liên tục, trong trường hợp này là dữ liệu âm thanh và chuỗi đích, trong trường hợp này là bản chép lại văn bản. Bằng cách đặt giá trị mean , tổn thất đầu ra của một lô đầu vào sẽ được chia cho độ dài mục tiêu. Sau đó, giá trị trung bình của lô sẽ được tính toán và mức giảm được áp dụng cho các giá trị tổn thất.
  • pad_token_id chỉ định mã thông báo được sử dụng để đệm khi tạo khối. Nó được đặt thành id [PAD] được đặt khi khởi tạo mã thông báo ở Bước 3.17 .
  • Tham số vocab_size xác định kích thước từ vựng của mô hình. Đó là kích thước từ vựng sau khi khởi tạo mã thông báo ở Bước 3.17 và phản ánh số lượng nút lớp đầu ra của phần chuyển tiếp của mạng.

Bước 3.26 - CELL 26: Tính năng đóng băng Trọng lượng trích xuất

Ô thứ 26 đóng băng các trọng số đã được huấn luyện trước của bộ trích xuất đặc điểm. Đặt ô thứ hai mươi sáu thành:


 ### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()

Bước 3.27 - Ô 27: Thiết lập đối số huấn luyện

Ô thứ 27 khởi tạo các đối số huấn luyện sẽ được chuyển đến phiên bản Trainer . Đặt ô thứ hai mươi bảy thành:


 ### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )


  • Lớp TrainingArguments chấp nhận hơn 100 tham số .
  • Tham số save_safetensors khi False chỉ định rằng mô hình đã tinh chỉnh sẽ được lưu vào tệp pickle thay vì sử dụng định dạng safetensors .
  • Tham số group_by_length khi True cho biết rằng các mẫu có độ dài xấp xỉ nhau sẽ được nhóm lại với nhau. Điều này giảm thiểu phần đệm và cải thiện hiệu quả đào tạo.
  • per_device_train_batch_size đặt số lượng mẫu cho mỗi đợt đào tạo nhỏ. Tham số này được đặt thành 18 thông qua hằng số TRAIN_BATCH_SIZE được chỉ định ở Bước 3.5 . Điều này ngụ ý 160 bước mỗi kỷ nguyên.
  • per_device_eval_batch_size đặt số lượng mẫu cho mỗi lô nhỏ đánh giá (tạm giữ). Tham số này được đặt thành 10 thông qua hằng số EVAL_BATCH_SIZE được chỉ định ở Bước 3.5 .
  • num_train_epochs đặt số lượng kỷ nguyên đào tạo. Tham số này được đặt thành 30 thông qua hằng số TRAIN_EPOCHS được chỉ định ở Bước 3.5 . Điều này ngụ ý tổng số 4.800 bước trong quá trình đào tạo.
  • Tham số gradient_checkpointing khi True giúp tiết kiệm bộ nhớ bằng cách kiểm tra các phép tính gradient nhưng dẫn đến tốc độ lùi lại chậm hơn.
  • Tham số evaluation_strategy khi được đặt thành steps có nghĩa là việc đánh giá sẽ được thực hiện và ghi lại trong quá trình đào tạo ở khoảng thời gian được chỉ định bởi tham số eval_steps .
  • Tham logging_strategy khi được đặt thành steps có nghĩa là số liệu thống kê về lần chạy huấn luyện sẽ được ghi lại theo khoảng thời gian được chỉ định bởi tham logging_steps .
  • Tham số save_strategy khi được đặt thành steps có nghĩa là điểm kiểm tra của mô hình đã tinh chỉnh sẽ được lưu trong khoảng thời gian được chỉ định bởi tham số save_steps .
  • eval_steps đặt số bước giữa các lần đánh giá dữ liệu loại trừ. Tham số này được đặt thành 100 thông qua hằng số EVAL_STEPS được gán ở Bước 3.5 .
  • save_steps đặt số bước sau đó điểm kiểm tra của mô hình đã tinh chỉnh sẽ được lưu. Tham số này được đặt thành 3200 thông qua hằng số SAVE_STEPS được gán ở Bước 3.5 .
  • logging_steps đặt số bước giữa các nhật ký thống kê về lần chạy tập luyện. Tham số này được đặt thành 100 thông qua hằng số LOGGING_STEPS được gán ở Bước 3.5 .
  • Tham số learning_rate đặt tốc độ học ban đầu. Tham số này được đặt thành 1e-4 thông qua hằng số LEARNING_RATE được gán ở Bước 3.5 .
  • Tham số warmup_steps đặt số bước để tăng tốc độ học tập một cách tuyến tính từ 0 đến giá trị do learning_rate đặt. Tham số này được đặt thành 800 thông qua hằng số WARMUP_STEPS được gán ở Bước 3.5 .

Bước 3.28 - Ô 28: Xác định logic của bộ đối chiếu dữ liệu

Ô thứ hai mươi tám xác định logic cho các chuỗi mục tiêu và đầu vào đệm động. Đặt ô thứ hai mươi tám thành:


 ### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch


  • Các cặp nhãn đầu vào đào tạo và đánh giá được chuyển theo từng đợt nhỏ tới phiên bản Trainer . Các cặp này sẽ được khởi tạo trong giây lát ở Bước 3.30 . Do các chuỗi đầu vào và chuỗi nhãn có độ dài khác nhau trong mỗi lô nhỏ nên một số chuỗi phải được đệm để chúng có cùng độ dài.
  • Lớp DataCollatorCTCWithPadding tự động đệm dữ liệu theo lô nhỏ. Thông số padding khi được đặt thành True chỉ định rằng chuỗi tính năng đầu vào âm thanh ngắn hơn và chuỗi nhãn phải có cùng độ dài với chuỗi dài nhất trong một lô nhỏ.
  • Các tính năng đầu vào âm thanh được đệm bằng giá trị 0.0 được đặt khi khởi chạy trình trích xuất tính năng ở Bước 3.18 .
  • Đầu vào nhãn trước tiên được đệm bằng giá trị đệm được đặt khi khởi tạo bộ mã thông báo ở Bước 3.17 . Các giá trị này được thay thế bằng -100 để các nhãn này bị bỏ qua khi tính chỉ số WER.

Bước 3.29 - CELL 29: Khởi tạo Instance của Data Collator

Ô thứ 29 khởi tạo một phiên bản của bộ đối chiếu dữ liệu được xác định ở bước trước. Đặt ô thứ hai mươi chín thành:


 ### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)

Bước 3.30 - Ô 30: Khởi tạo Trainer

Ô thứ ba mươi khởi tạo một thể hiện của lớp Trainer . Đặt ô thứ ba mươi thành:


 ### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )


  • Như đã thấy, lớp Trainer được khởi tạo bằng:
    • model được huấn luyện trước được khởi tạo ở Bước 3.25 .
    • Trình đối chiếu dữ liệu được khởi tạo ở Bước 3.29 .
    • Các đối số huấn luyện được khởi tạo ở Bước 3.27 .
    • Phương pháp đánh giá WER được xác định ở Bước 3.9 .
    • Đối tượng Dataset train_data từ Bước 3.24 .
    • Đối tượng Dataset valid_data từ Bước 3.24 .
  • Tham số tokenizer được gán cho processor.feature_extractor và hoạt động với data_collator để tự động đệm đầu vào vào đầu vào có độ dài tối đa của mỗi lô nhỏ.

Bước 3.31 - Ô 31: Tinh chỉnh mô hình

Ô thứ 31 gọi phương thức train trên phiên bản lớp Trainer để tinh chỉnh mô hình. Đặt ô thứ ba mươi mốt thành:


 ### CELL 31: Finetune the model ### trainer.train()

Bước 3.32 - Ô 32: Lưu mô hình đã tinh chỉnh

Ô thứ ba mươi hai là ô sổ tay cuối cùng. Nó lưu mô hình đã tinh chỉnh bằng cách gọi phương thức save_model trên phiên bản Trainer . Đặt ô thứ ba mươi giây thành:


 ### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)

Bước 4 - Đào tạo và lưu mô hình

Bước 4.1 - Đào tạo mô hình

Bây giờ tất cả các ô của sổ ghi chép đã được tạo xong, đã đến lúc bắt đầu tinh chỉnh.


  1. Đặt Máy tính xách tay Kaggle chạy với bộ tăng tốc NVIDIA GPU P100 .

  2. Cam kết sổ ghi chép trên Kaggle.

  3. Giám sát dữ liệu lần chạy tập luyện bằng cách đăng nhập vào tài khoản WandB của bạn và định vị lần chạy liên quan.


Quá trình đào tạo trên 30 kỷ nguyên sẽ mất khoảng 5 giờ bằng cách sử dụng bộ tăng tốc NVIDIA GPU P100. WER trên dữ liệu loại trừ sẽ giảm xuống ~ 0,15 khi kết thúc khóa đào tạo. Đây không hẳn là một kết quả hiện đại nhưng mô hình đã được tinh chỉnh vẫn đủ hữu ích cho nhiều ứng dụng.

Bước 4.2 - Lưu mô hình

Mô hình đã tinh chỉnh sẽ được xuất ra thư mục Kaggle được chỉ định bởi hằng số OUTPUT_DIR_PATH được chỉ định trong Bước 3.5 . Đầu ra của mô hình phải bao gồm các tệp sau:


 pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin


Những tập tin này có thể được tải xuống cục bộ. Ngoài ra, bạn có thể tạo Mô hình Kaggle mới bằng cách sử dụng các tệp mô hình. Mô hình Kaggle sẽ được sử dụng cùng với hướng dẫn suy luận đồng hành để chạy suy luận trên mô hình đã được tinh chỉnh.


  1. Đăng nhập vào tài khoản Kaggle của bạn. Bấm vào Mô hình > Mô hình mới .
  2. Thêm tiêu đề cho mô hình đã tinh chỉnh của bạn trong trường Tiêu đề mô hình .
  3. Bấm vào Tạo mô hình .
  4. Bấm vào Đi tới trang chi tiết mô hình .
  5. Nhấp vào Thêm biến thể mới trong Biến thể mẫu .
  6. Chọn Transformers từ menu chọn Framework .
  7. Nhấp vào Thêm biến thể mới .
  8. Kéo và thả các tệp mô hình đã tinh chỉnh của bạn vào cửa sổ Tải lên dữ liệu . Ngoài ra, hãy nhấp vào nút Duyệt tệp để mở cửa sổ trình khám phá tệp và chọn các tệp mô hình đã được tinh chỉnh của bạn.
  9. Khi các file đã được tải lên Kaggle, hãy nhấp vào Create để tạo Kaggle Model .

Phần kết luận

Chúc mừng bạn đã hoàn thiện wav2vec2 XLS-R! Hãy nhớ rằng bạn có thể sử dụng các bước chung này để tinh chỉnh mô hình trên các ngôn ngữ khác mà bạn mong muốn. Việc chạy suy luận trên mô hình tinh chỉnh được tạo trong hướng dẫn này khá đơn giản. Các bước suy luận sẽ được trình bày trong hướng dẫn đồng hành riêng cho hướng dẫn này. Vui lòng tìm kiếm tên người dùng HackerNoon của tôi để tìm hướng dẫn đồng hành.