paint-brush
Nghiên cứu điển hình về phân loại văn bản học máy với khuynh hướng hướng đến sản phẩmtừ tác giả@bemorelavender
29,341 lượt đọc
29,341 lượt đọc

Nghiên cứu điển hình về phân loại văn bản học máy với khuynh hướng hướng đến sản phẩm

từ tác giả Maria K17m2024/03/12
Read on Terminal Reader

dài quá đọc không nổi

Đây là một nghiên cứu điển hình về học máy với khuynh hướng hướng đến sản phẩm: chúng ta sẽ giả vờ rằng chúng ta có một sản phẩm thực tế mà chúng ta cần cải thiện. Chúng ta sẽ khám phá một tập dữ liệu và thử các mô hình khác nhau như hồi quy logistic, mạng thần kinh tái phát và máy biến áp, xem mức độ chính xác của chúng, cách chúng sẽ cải thiện sản phẩm, tốc độ hoạt động của chúng và liệu chúng có dễ gỡ lỗi hay không và mở rộng quy mô.
featured image - Nghiên cứu điển hình về phân loại văn bản học máy với khuynh hướng hướng đến sản phẩm
Maria K HackerNoon profile picture


Chúng ta sẽ giả vờ như chúng ta có một sản phẩm thực sự cần cải tiến. Chúng ta sẽ khám phá một tập dữ liệu và thử các mô hình khác nhau như hồi quy logistic, mạng thần kinh tái phát và máy biến áp, xem mức độ chính xác của chúng, cách chúng sẽ cải thiện sản phẩm, tốc độ hoạt động của chúng và liệu chúng có dễ gỡ lỗi hay không và mở rộng quy mô.


Bạn có thể đọc mã nghiên cứu điển hình đầy đủ trên GitHub và xem sổ ghi chép phân tích với các biểu đồ tương tác trong Jupyter Notebook Viewer .


Hào hứng? Chúng ta hãy đi đến đó!

Cài đặt tác vụ

Hãy tưởng tượng chúng ta sở hữu một trang web thương mại điện tử. Trên trang web này, người bán có thể tải lên mô tả về mặt hàng họ muốn bán. Họ cũng phải chọn danh mục vật phẩm theo cách thủ công, điều này có thể làm chúng chậm lại.


Nhiệm vụ của chúng tôi là tự động hóa việc lựa chọn danh mục dựa trên mô tả mặt hàng. Tuy nhiên, một lựa chọn tự động hóa sai còn tệ hơn là không tự động hóa, bởi vì một sai sót có thể không được chú ý, điều này có thể dẫn đến tổn thất về doanh số bán hàng. Vì vậy, chúng tôi có thể chọn không đặt nhãn tự động nếu không chắc chắn.


Đối với trường hợp nghiên cứu này, chúng tôi sẽ sử dụng Bộ dữ liệu văn bản thương mại điện tử Zenodo , chứa mô tả và danh mục mục.


Tốt hay xấu? Cách chọn mẫu tốt nhất

Chúng tôi sẽ xem xét nhiều kiến trúc mô hình bên dưới và việc quyết định cách chọn tùy chọn tốt nhất trước khi bắt đầu luôn là một phương pháp hay. Mô hình này sẽ tác động đến sản phẩm của chúng ta như thế nào? …cơ sở hạ tầng của chúng tôi?


Rõ ràng, chúng ta sẽ có thước đo chất lượng kỹ thuật để so sánh các mô hình khác nhau nhé. Trong trường hợp này, chúng tôi có nhiệm vụ phân loại nhiều lớp, vì vậy hãy sử dụng điểm chính xác cân bằng để xử lý tốt các nhãn không cân bằng.


Tất nhiên, giai đoạn cuối cùng điển hình của việc kiểm tra ứng viên là kiểm tra AB - giai đoạn trực tuyến, mang lại bức tranh rõ hơn về mức độ ảnh hưởng của sự thay đổi đối với khách hàng. Thông thường, kiểm tra AB tốn nhiều thời gian hơn so với kiểm tra ngoại tuyến, do đó chỉ những ứng viên tốt nhất từ giai đoạn ngoại tuyến mới được kiểm tra. Đây là một nghiên cứu điển hình và chúng tôi không có người dùng thực tế nên chúng tôi sẽ không đề cập đến thử nghiệm AB.


Chúng ta nên cân nhắc điều gì khác trước khi chuyển ứng viên sang thử nghiệm AB? Chúng ta có thể nghĩ gì trong giai đoạn ngoại tuyến để tiết kiệm thời gian thử nghiệm trực tuyến và đảm bảo rằng chúng ta đang thực sự thử nghiệm giải pháp tốt nhất có thể?


Biến các số liệu kỹ thuật thành các số liệu định hướng tác động

Độ chính xác cân bằng là rất tốt, nhưng điểm này không trả lời được câu hỏi “Mô hình này sẽ tác động đến sản phẩm chính xác như thế nào?”. Để tìm thêm điểm định hướng sản phẩm, chúng ta phải hiểu cách chúng ta sẽ sử dụng mô hình.


Trong bối cảnh của chúng tôi, mắc lỗi còn tệ hơn là không đưa ra câu trả lời, vì người bán sẽ phải nhận ra lỗi và thay đổi danh mục theo cách thủ công. Một sai sót không được chú ý sẽ làm giảm doanh số bán hàng và khiến trải nghiệm người dùng của người bán trở nên tồi tệ hơn, chúng ta có nguy cơ mất khách hàng.


Để tránh điều đó, chúng tôi sẽ chọn ngưỡng cho điểm của mô hình để chỉ cho phép mình mắc 1% lỗi. Sau đó, số liệu định hướng sản phẩm có thể được đặt như sau:


Chúng tôi có thể tự động phân loại bao nhiêu phần trăm các mặt hàng nếu khả năng chịu lỗi của chúng tôi chỉ là 1%?


Chúng tôi sẽ gọi đây là Automatic categorisation percentage bên dưới khi chọn mô hình tốt nhất. Tìm mã lựa chọn ngưỡng đầy đủ tại đây .


Thời gian suy luận

Một mô hình mất bao lâu để xử lý một yêu cầu?


Điều này gần như sẽ cho phép chúng tôi so sánh lượng tài nguyên mà chúng tôi sẽ phải duy trì cho một dịch vụ để xử lý tải tác vụ nếu một mô hình được chọn thay vì một mô hình khác.


Khả năng mở rộng

Khi sản phẩm của chúng tôi sắp phát triển, việc quản lý sự tăng trưởng bằng cách sử dụng kiến trúc nhất định sẽ dễ dàng đến mức nào?


Khi nói đến tăng trưởng, chúng ta có thể hiểu là:

  • nhiều danh mục hơn, mức độ chi tiết cao hơn của danh mục
  • mô tả dài hơn
  • tập dữ liệu lớn hơn
  • vân vân

Liệu chúng ta có phải suy nghĩ lại về việc lựa chọn mô hình để xử lý sự tăng trưởng hay chỉ cần đào tạo lại đơn giản là đủ?


Khả năng giải thích

Việc gỡ lỗi của mô hình trong khi đào tạo và sau khi triển khai sẽ dễ dàng như thế nào?


Kích thước mô hình

Kích thước mô hình quan trọng nếu:

  • chúng tôi muốn mô hình của chúng tôi được đánh giá ở phía khách hàng
  • nó lớn đến mức không thể nhét vừa RAM


Sau này chúng ta sẽ thấy rằng cả hai mục trên đều không liên quan, nhưng vẫn đáng để xem xét ngắn gọn.

Khám phá và làm sạch tập dữ liệu

Chúng ta đang làm việc với cái gì? Hãy nhìn vào dữ liệu và xem liệu nó có cần được dọn dẹp không!


Tập dữ liệu chứa 2 cột: mô tả mục và danh mục, tổng cộng 50,5 nghìn hàng.

 file_name = "ecommerceDataset.csv" data = pd.read_csv(file_name, header=None) data.columns = ["category", "description"] print("Rows, cols:", data.shape) # >>> Rows, cols: (50425, 2)


Mỗi mặt hàng được gán 1 trong 4 danh mục có sẵn: Household , Books , Electronics hoặc Clothing & Accessories . Dưới đây là 1 ví dụ về mô tả mục cho mỗi danh mục:


  • Hộ gia đình SPK Trang trí nội thất Mặt nạ đất sét thủ công (Nhiều màu, H35xW12cm) Làm cho ngôi nhà của bạn đẹp hơn với sản phẩm treo tường Mặt nạ đất nung Ấn Độ thủ công này, chưa bao giờ bạn không thể bắt gặp món đồ thủ công này trên thị trường. Bạn có thể thêm nó vào phòng khách/Sảnh vào của bạn.


  • Sách BEGF101/FEG1-Khóa học cơ bản bằng tiếng Anh-1 (Ấn bản Neeraj Publications 2018) BEGF101/FEG1-Khóa học cơ bản bằng tiếng Anh-1


  • Quần áo & Phụ kiện Áo dungaree denim dành cho nữ của Broadstar Nhận được thẻ toàn quyền truy cập khi mặc áo dungaree của Broadstar. Được làm bằng vải denim, những chiếc quần lửng này sẽ giúp bạn luôn thoải mái. Hãy kết hợp chúng với áo sơ mi màu trắng hoặc đen để hoàn thiện vẻ ngoài thường ngày của bạn.


  • Điện Tử Caprigo Heavy Duty - Chân Đế Máy Chiếu Cao Cấp Gắn Trần 2 Feet (Có Thể Điều Chỉnh - Trắng - Trọng Lượng Chịu Tải 15 Kgs)


Giá trị bị mất

Chỉ có một giá trị trống trong tập dữ liệu mà chúng tôi sẽ xóa.

 print(data.info()) # <class 'pandas.core.frame.DataFrame'> # RangeIndex: 50425 entries, 0 to 50424 # Data columns (total 2 columns): # # Column Non-Null Count Dtype # --- ------ -------------- ----- # 0 category 50425 non-null object # 1 description 50424 non-null object # dtypes: object(2) # memory usage: 788.0+ KB data.dropna(inplace=True)


trùng lặp

Tuy nhiên có khá nhiều mô tả trùng lặp. May mắn thay, tất cả các bản sao đều thuộc một danh mục nên chúng ta có thể loại bỏ chúng một cách an toàn.

 repeated_messages = data \ .groupby("description", as_index=False) \ .agg( n_repeats=("category", "count"), n_unique_categories=("category", lambda x: len(np.unique(x))) ) repeated_messages = repeated_messages[repeated_messages["n_repeats"] > 1] print(f"Count of repeated messages (unique): {repeated_messages.shape[0]}") print(f"Total number: {repeated_messages['n_repeats'].sum()} out of {data.shape[0]}") # >>> Count of repeated messages (unique): 13979 # >>> Total number: 36601 out of 50424


Sau khi loại bỏ các bản sao, chúng tôi còn lại 55% tập dữ liệu ban đầu. Bộ dữ liệu được cân bằng tốt.

 data.drop_duplicates(inplace=True) print(f"New dataset size: {data.shape}") print(data["category"].value_counts()) # New dataset size: (27802, 2) # Household 10564 # Books 6256 # Clothing & Accessories 5674 # Electronics 5308 # Name: category, dtype: int64


Ngôn ngữ mô tả

Lưu ý rằng theo mô tả tập dữ liệu,

Bộ dữ liệu đã được lấy từ nền tảng thương mại điện tử Ấn Độ.


Các mô tả không nhất thiết phải được viết bằng tiếng Anh. Một số trong số chúng được viết bằng tiếng Hindi hoặc các ngôn ngữ khác sử dụng các ký hiệu không phải ASCII hoặc được chuyển ngữ sang bảng chữ cái Latinh hoặc sử dụng kết hợp nhiều ngôn ngữ. Ví dụ từ danh mục Books :


  • यू जी सी – नेट जूनियर रिसर्च फैलोशिप एवं सहायक प्रोफेसर योग्यता …
  • Prarambhik Bhartiy Itihas
  • History of NORTH INDIA/வட இந்திய வரலாறு/ …


Để đánh giá sự hiện diện của các từ không phải tiếng Anh trong phần mô tả, hãy tính 2 điểm:


  • Điểm ASCII: tỷ lệ phần trăm ký hiệu không phải ASCII trong mô tả
  • Điểm số từ tiếng Anh hợp lệ: nếu chúng ta chỉ xem xét các chữ cái Latinh, thì bao nhiêu phần trăm từ trong mô tả hợp lệ bằng tiếng Anh? Giả sử rằng các từ tiếng Anh hợp lệ là những từ có trong Word2Vec-300 được đào tạo trên kho ngữ liệu tiếng Anh.


Sử dụng điểm ASCII, chúng tôi biết rằng chỉ 2,3% mô tả bao gồm hơn 1% ký hiệu không phải ASCII.

 def get_ascii_score(description): total_sym_cnt = 0 ascii_sym_cnt = 0 for sym in description: total_sym_cnt += 1 if sym.isascii(): ascii_sym_cnt += 1 return ascii_sym_cnt / total_sym_cnt data["ascii_score"] = data["description"].apply(get_ascii_score) data[data["ascii_score"] < 0.99].shape[0] / data.shape[0] # >>> 0.023


Điểm số từ tiếng Anh hợp lệ cho thấy chỉ 1,5% mô tả có ít hơn 70% từ tiếng Anh hợp lệ trong số các từ ASCII.

 w2v_eng = gensim.models.KeyedVectors.load_word2vec_format(w2v_path, binary=True) def get_valid_eng_score(description): description = re.sub("[^az \t]+", " ", description.lower()) total_word_cnt = 0 eng_word_cnt = 0 for word in description.split(): total_word_cnt += 1 if word.lower() in w2v_eng: eng_word_cnt += 1 return eng_word_cnt / total_word_cnt data["eng_score"] = data["description"].apply(get_valid_eng_score) data[data["eng_score"] < 0.7].shape[0] / data.shape[0] # >>> 0.015


Do đó phần lớn các mô tả (khoảng 96%) đều bằng tiếng Anh hoặc chủ yếu bằng tiếng Anh. Chúng ta có thể xóa tất cả các mô tả khác nhưng thay vào đó, hãy để nguyên chúng và sau đó xem cách mỗi mô hình xử lý chúng.

Làm người mẫu

Hãy chia tập dữ liệu của chúng tôi thành 3 nhóm:

  • Đào tạo 70% - để đào tạo người mẫu (tin nhắn 19k)

  • Kiểm tra 15% - để chọn tham số và ngưỡng (tin nhắn 4,1k)

  • Đánh giá 15% - để chọn mô hình cuối cùng (tin nhắn 4.1k)


 from sklearn.model_selection import train_test_split data_train, data_test = train_test_split(data, test_size=0.3) data_test, data_eval = train_test_split(data_test, test_size=0.5) data_train.shape, data_test.shape, data_eval.shape # >>> ((19461, 3), (4170, 3), (4171, 3))


Mô hình cơ sở: túi từ + hồi quy logistic

Sẽ rất hữu ích nếu bạn làm điều gì đó đơn giản và tầm thường lúc đầu để có được cơ sở tốt. Để làm cơ sở, hãy tạo một túi cấu trúc từ dựa trên tập dữ liệu huấn luyện.


Chúng ta cũng hãy giới hạn kích thước từ điển ở mức 100 từ.

 count_vectorizer = CountVectorizer(max_features=100, stop_words="english") x_train_baseline = count_vectorizer.fit_transform(data_train["description"]) y_train_baseline = data_train["category"] x_test_baseline = count_vectorizer.transform(data_test["description"]) y_test_baseline = data_test["category"] x_train_baseline = x_train_baseline.toarray() x_test_baseline = x_test_baseline.toarray()


Tôi đang dự định sử dụng hồi quy logistic làm mô hình, vì vậy tôi cần chuẩn hóa các tính năng của bộ đếm trước khi đào tạo.

 ss = StandardScaler() x_train_baseline = ss.fit_transform(x_train_baseline) x_test_baseline = ss.transform(x_test_baseline) lr = LogisticRegression() lr.fit(x_train_baseline, y_train_baseline) balanced_accuracy_score(y_test_baseline, lr.predict(x_test_baseline)) # >>> 0.752


Hồi quy logistic đa lớp cho thấy độ chính xác cân bằng 75,2%. Đây là một cơ sở tuyệt vời!


Mặc dù chất lượng phân loại tổng thể không tốt nhưng mô hình vẫn có thể cung cấp cho chúng ta một số thông tin chi tiết. Chúng ta hãy xem ma trận nhầm lẫn, được chuẩn hóa bằng số lượng nhãn được dự đoán. Trục X biểu thị danh mục được dự đoán và trục Y - danh mục thực. Nhìn vào từng cột chúng ta có thể thấy sự phân bổ của các danh mục thực tế khi dự đoán một danh mục nhất định.


Ma trận nhầm lẫn cho giải pháp cơ bản.


Ví dụ: Electronics thường bị nhầm lẫn với Household . Nhưng ngay cả mẫu đơn giản này cũng có thể chụp Clothing & Accessories khá chính xác.


Dưới đây là tầm quan trọng của tính năng khi dự đoán danh mục Clothing & Accessories :

Tầm quan trọng của tính năng đối với giải pháp cơ bản cho nhãn 'Quần áo & Phụ kiện'


Top 6 từ đóng góp nhiều nhất cho và chống lại danh mục Clothing & Accessories :

 women 1.49 book -2.03 men 0.93 table -1.47 cotton 0.92 author -1.11 wear 0.69 books -1.10 fit 0.40 led -0.90 stainless 0.36 cable -0.85


RNN

Bây giờ chúng ta hãy xem xét các mô hình nâng cao hơn, được thiết kế đặc biệt để hoạt động với các chuỗi - mạng thần kinh tái phát . GRULSTM là các lớp nâng cao phổ biến để chống lại sự bùng nổ độ dốc xảy ra trong các RNN đơn giản.


Chúng tôi sẽ sử dụng thư viện pytorch để mã hóa các mô tả cũng như xây dựng và huấn luyện mô hình.


Đầu tiên chúng ta cần chuyển văn bản thành số:

  1. Chia mô tả thành các từ
  2. Gán chỉ mục cho mỗi từ trong kho văn bản dựa trên tập dữ liệu huấn luyện
  3. Dự trữ các chỉ số đặc biệt cho các từ và phần đệm chưa biết
  4. Chuyển đổi từng mô tả trong tập dữ liệu huấn luyện và kiểm tra thành vectơ chỉ số.


Từ vựng mà chúng tôi nhận được chỉ bằng cách mã hóa tập dữ liệu xe lửa là rất lớn - gần 90 nghìn từ. Chúng ta càng có nhiều từ thì không gian nhúng mà mô hình phải học càng lớn. Để đơn giản hóa việc đào tạo, hãy xóa những từ hiếm nhất khỏi nó và chỉ để lại những từ xuất hiện trong ít nhất 3% mô tả. Điều này sẽ cắt bớt từ vựng xuống còn 340 từ.

(tìm cách triển khai CorpusDictionary đầy đủ tại đây )


 corpus_dict = util.CorpusDictionary(data_train["description"]) corpus_dict.truncate_dictionary(min_frequency=0.03) data_train["vector"] = corpus_dict.transform(data_train["description"]) data_test["vector"] = corpus_dict.transform(data_test["description"]) print(data_train["vector"].head()) # 28453 [1, 1, 1, 1, 12, 1, 2, 1, 6, 1, 1, 1, 1, 1, 6,... # 48884 [1, 1, 13, 34, 3, 1, 1, 38, 12, 21, 2, 1, 37, ... # 36550 [1, 60, 61, 1, 62, 60, 61, 1, 1, 1, 1, 10, 1, ... # 34999 [1, 34, 1, 1, 75, 60, 61, 1, 1, 72, 1, 1, 67, ... # 19183 [1, 83, 1, 1, 87, 1, 1, 1, 12, 21, 42, 1, 2, 1... # Name: vector, dtype: object


Điều tiếp theo chúng ta cần quyết định là độ dài chung của các vectơ mà chúng ta sẽ đưa vào làm đầu vào cho RNN. Chúng tôi không muốn sử dụng vectơ đầy đủ vì mô tả dài nhất chứa 9,4k mã thông báo.


Tuy nhiên, 95% mô tả trong tập dữ liệu tàu không dài hơn 352 mã thông báo - đó là độ dài phù hợp để cắt bớt. Điều gì sẽ xảy ra với những mô tả ngắn hơn?


Chúng sẽ được đệm bằng chỉ số đệm theo chiều dài chung.

 print(max(data_train["vector"].apply(len))) # >>> 9388 print(int(np.quantile(data_train["vector"].apply(len), q=0.95))) # >>> 352


Tiếp theo - chúng ta cần chuyển đổi các danh mục mục tiêu thành các vectơ 0-1 để tính toán tổn thất và thực hiện lan truyền ngược trên mỗi bước huấn luyện.

 def get_target(label, total_labels=4): target = [0] * total_labels target[label_2_idx.get(label)] = 1 return target data_train["target"] = data_train["category"].apply(get_target) data_test["target"] = data_test["category"].apply(get_target)


Bây giờ chúng ta đã sẵn sàng tạo Bộ dữ liệu và Trình tải dữ liệu pytorch tùy chỉnh để đưa vào mô hình. Tìm bản triển khai PaddedTextVectorDataset đầy đủ tại đây .

 ds_train = util.PaddedTextVectorDataset( data_train["description"], data_train["target"], corpus_dict, max_vector_len=352, ) ds_test = util.PaddedTextVectorDataset( data_test["description"], data_test["target"], corpus_dict, max_vector_len=352, ) train_dl = DataLoader(ds_train, batch_size=512, shuffle=True) test_dl = DataLoader(ds_test, batch_size=512, shuffle=False)


Cuối cùng, hãy xây dựng một mô hình.


Kiến trúc tối thiểu là:

  • lớp nhúng
  • lớp RNN
  • lớp tuyến tính
  • lớp kích hoạt


Bắt đầu với các giá trị tham số nhỏ (kích thước của vectơ nhúng, kích thước của lớp ẩn trong RNN, số lượng lớp RNN) và không chính quy hóa, chúng ta có thể dần dần làm cho mô hình trở nên phức tạp hơn cho đến khi nó có dấu hiệu quá khớp, sau đó cân bằng chính quy hóa (bỏ lớp RNN và trước lớp tuyến tính cuối cùng).


 class GRU(nn.Module): def __init__(self, vocab_size, embedding_dim, n_hidden, n_out): super().__init__() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.n_hidden = n_hidden self.n_out = n_out self.emb = nn.Embedding(self.vocab_size, self.embedding_dim) self.gru = nn.GRU(self.embedding_dim, self.n_hidden) self.dropout = nn.Dropout(0.3) self.out = nn.Linear(self.n_hidden, self.n_out) def forward(self, sequence, lengths): batch_size = sequence.size(1) self.hidden = self._init_hidden(batch_size) embs = self.emb(sequence) embs = pack_padded_sequence(embs, lengths, enforce_sorted=True) gru_out, self.hidden = self.gru(embs, self.hidden) gru_out, lengths = pad_packed_sequence(gru_out) dropout = self.dropout(self.hidden[-1]) output = self.out(dropout) return F.log_softmax(output, dim=-1) def _init_hidden(self, batch_size): return Variable(torch.zeros((1, batch_size, self.n_hidden)))


Chúng tôi sẽ sử dụng trình tối ưu hóa Adamcross_entropy làm hàm mất mát.


 vocab_size = len(corpus_dict.word_to_idx) emb_dim = 4 n_hidden = 15 n_out = len(label_2_idx) model = GRU(vocab_size, emb_dim, n_hidden, n_out) opt = optim.Adam(model.parameters(), 1e-2) util.fit( model=model, train_dl=train_dl, test_dl=test_dl, loss_fn=F.cross_entropy, opt=opt, epochs=35 ) # >>> Train loss: 0.3783 # >>> Val loss: 0.4730 

Tổn thất đào tạo và kiểm tra mỗi kỷ nguyên, mô hình RNN

Mô hình này cho thấy độ chính xác cân bằng 84,3% trên tập dữ liệu eval. Wow, thật là tiến bộ!


Giới thiệu các phần nhúng được đào tạo trước

Nhược điểm chính của việc đào tạo mô hình RNN từ đầu là nó phải học nghĩa của chính các từ - đó là công việc của lớp nhúng. Các mô hình word2vec được đào tạo trước có sẵn để sử dụng làm lớp nhúng được tạo sẵn, giúp giảm số lượng tham số và tăng thêm nhiều ý nghĩa hơn cho mã thông báo. Hãy sử dụng một trong các mô hình word2vec có sẵn trong pytorch - glove, dim=300 .


Chúng tôi chỉ cần thực hiện những thay đổi nhỏ đối với việc tạo Tập dữ liệu - giờ đây chúng tôi muốn tạo một vectơ gồm các chỉ mục được xác định trước glove cho từng mô tả và kiến trúc mô hình.

 ds_emb_train = util.PaddedTextVectorDataset( data_train["description"], data_train["target"], emb=glove, max_vector_len=max_len, ) ds_emb_test = util.PaddedTextVectorDataset( data_test["description"], data_test["target"], emb=glove, max_vector_len=max_len, ) dl_emb_train = DataLoader(ds_emb_train, batch_size=512, shuffle=True) dl_emb_test = DataLoader(ds_emb_test, batch_size=512, shuffle=False)
 import torchtext.vocab as vocab glove = vocab.GloVe(name='6B', dim=300) class LSTMPretrained(nn.Module): def __init__(self, n_hidden, n_out): super().__init__() self.emb = nn.Embedding.from_pretrained(glove.vectors) self.emb.requires_grad_ = False self.embedding_dim = 300 self.n_hidden = n_hidden self.n_out = n_out self.lstm = nn.LSTM(self.embedding_dim, self.n_hidden, num_layers=1) self.dropout = nn.Dropout(0.5) self.out = nn.Linear(self.n_hidden, self.n_out) def forward(self, sequence, lengths): batch_size = sequence.size(1) self.hidden = self.init_hidden(batch_size) embs = self.emb(sequence) embs = pack_padded_sequence(embs, lengths, enforce_sorted=True) lstm_out, (self.hidden, _) = self.lstm(embs) lstm_out, lengths = pad_packed_sequence(lstm_out) dropout = self.dropout(self.hidden[-1]) output = self.out(dropout) return F.log_softmax(output, dim=-1) def init_hidden(self, batch_size): return Variable(torch.zeros((1, batch_size, self.n_hidden)))


Và chúng tôi đã sẵn sàng để đào tạo!

 n_hidden = 50 n_out = len(label_2_idx) emb_model = LSTMPretrained(n_hidden, n_out) opt = optim.Adam(emb_model.parameters(), 1e-2) util.fit(model=emb_model, train_dl=dl_emb_train, test_dl=dl_emb_test, loss_fn=F.cross_entropy, opt=opt, epochs=11) 

Tổn thất đào tạo và kiểm tra trên mỗi kỷ nguyên, mô hình RNN + phần nhúng được đào tạo trước

Bây giờ chúng tôi đang nhận được độ chính xác cân bằng 93,7% trên tập dữ liệu eval. Ôi!


BERT

Các mô hình hiện đại nhất để làm việc với trình tự là máy biến áp. Tuy nhiên, để huấn luyện máy biến áp từ đầu, chúng ta sẽ cần một lượng lớn dữ liệu và tài nguyên tính toán. Những gì chúng ta có thể thử ở đây - là tinh chỉnh một trong những mô hình được đào tạo trước để phục vụ mục đích của chúng ta. Để làm điều này, chúng ta cần tải xuống mô hình BERT được đào tạo trước và thêm lớp tuyến tính và lớp tuyến tính để có được dự đoán cuối cùng. Bạn nên đào tạo một mô hình đã điều chỉnh trong 4 kỷ nguyên. Tôi chỉ đào tạo thêm 2 kỷ nguyên để tiết kiệm thời gian - tôi mất 40 phút để làm điều đó.


 from transformers import BertModel class BERTModel(nn.Module): def __init__(self, n_out=12): super(BERTModel, self).__init__() self.l1 = BertModel.from_pretrained('bert-base-uncased') self.l2 = nn.Dropout(0.3) self.l3 = nn.Linear(768, n_out) def forward(self, ids, mask, token_type_ids): output_1 = self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids) output_2 = self.l2(output_1.pooler_output) output = self.l3(output_2) return output


 ds_train_bert = bert.get_dataset( list(data_train["description"]), list(data_train["target"]), max_vector_len=64 ) ds_test_bert = bert.get_dataset( list(data_test["description"]), list(data_test["target"]), max_vector_len=64 ) dl_train_bert = DataLoader(ds_train_bert, sampler=RandomSampler(ds_train_bert), batch_size=batch_size) dl_test_bert = DataLoader(ds_test_bert, sampler=SequentialSampler(ds_test_bert), batch_size=batch_size)


 b_model = bert.BERTModel(n_out=4) b_model.to(torch.device("cpu")) def loss_fn(outputs, targets): return torch.nn.BCEWithLogitsLoss()(outputs, targets) optimizer = optim.AdamW(b_model.parameters(), lr=2e-5, eps=1e-8) epochs = 2 scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=total_steps ) bert.fit(b_model, dl_train_bert, dl_test_bert, optimizer, scheduler, loss_fn, device, epochs=epochs) torch.save(b_model, "models/bert_fine_tuned")


Nhật ký đào tạo:

 2024-02-29 19:38:13.383953 Epoch 1 / 2 Training... 2024-02-29 19:40:39.303002 step 40 / 305 done 2024-02-29 19:43:04.482043 step 80 / 305 done 2024-02-29 19:45:27.767488 step 120 / 305 done 2024-02-29 19:47:53.156420 step 160 / 305 done 2024-02-29 19:50:20.117272 step 200 / 305 done 2024-02-29 19:52:47.988203 step 240 / 305 done 2024-02-29 19:55:16.812437 step 280 / 305 done 2024-02-29 19:56:46.990367 Average training loss: 0.18 2024-02-29 19:56:46.990932 Validating... 2024-02-29 19:57:51.182859 Average validation loss: 0.10 2024-02-29 19:57:51.182948 Epoch 2 / 2 Training... 2024-02-29 20:00:25.110818 step 40 / 305 done 2024-02-29 20:02:56.240693 step 80 / 305 done 2024-02-29 20:05:25.647311 step 120 / 305 done 2024-02-29 20:07:53.668489 step 160 / 305 done 2024-02-29 20:10:33.936778 step 200 / 305 done 2024-02-29 20:13:03.217450 step 240 / 305 done 2024-02-29 20:15:28.384958 step 280 / 305 done 2024-02-29 20:16:57.004078 Average training loss: 0.08 2024-02-29 20:16:57.004657 Validating... 2024-02-29 20:18:01.546235 Average validation loss: 0.09


Cuối cùng, mô hình BERT được tinh chỉnh cho thấy độ chính xác cân bằng lên tới 95,1% trên tập dữ liệu eval.


Chọn người chiến thắng của chúng tôi

Chúng tôi đã thiết lập một danh sách các cân nhắc cần xem xét để đưa ra lựa chọn sáng suốt cuối cùng.

Dưới đây là biểu đồ hiển thị các thông số có thể đo được:

Chỉ số hiệu suất của mô hình


Mặc dù BERT được tinh chỉnh đang dẫn đầu về chất lượng, nhưng RNN với lớp nhúng được đào tạo trước LSTM+EMB chỉ đứng thứ hai, chỉ kém 3% so với các bài tập danh mục tự động.


Mặt khác, thời gian suy luận của BERT tinh chỉnh dài hơn 14 lần so với LSTM+EMB . Điều này sẽ làm tăng thêm chi phí bảo trì phụ trợ và có thể sẽ lớn hơn những lợi ích BERT tinh chỉnh mang lại cho LSTM+EMB .


Về khả năng tương tác, mô hình hồi quy logistic cơ bản của chúng tôi cho đến nay là dễ hiểu nhất và bất kỳ mạng thần kinh nào cũng thua nó về mặt này. Đồng thời, đường cơ sở có lẽ có khả năng mở rộng ít nhất - việc thêm các danh mục sẽ làm giảm chất lượng vốn đã thấp của đường cơ sở.


Mặc dù BERT có vẻ giống như siêu sao với độ chính xác cao, nhưng cuối cùng chúng tôi vẫn sử dụng RNN với lớp nhúng được đào tạo trước. Tại sao? Nó khá chính xác, không quá chậm và không quá phức tạp để xử lý khi mọi thứ trở nên lớn.


Hy vọng bạn thích nghiên cứu trường hợp này. Bạn sẽ chọn mô hình nào và tại sao?