paint-brush
Une étude de cas de classification de textes d'apprentissage automatique avec une touche axée sur le produitpar@bemorelavender
29,341 lectures
29,341 lectures

Une étude de cas de classification de textes d'apprentissage automatique avec une touche axée sur le produit

par Maria K17m2024/03/12
Read on Terminal Reader

Trop long; Pour lire

Il s'agit d'une étude de cas d'apprentissage automatique avec une touche axée sur le produit : nous allons prétendre que nous avons un produit réel que nous devons améliorer. Nous explorerons un ensemble de données et testerons différents modèles tels que la régression logistique, les réseaux de neurones récurrents et les transformateurs, en examinant leur précision, comment ils vont améliorer le produit, leur rapidité de fonctionnement et s'ils sont faciles à déboguer. et passer à l'échelle.
featured image - Une étude de cas de classification de textes d'apprentissage automatique avec une touche axée sur le produit
Maria K HackerNoon profile picture


Nous allons prétendre que nous avons un produit réel que nous devons améliorer. Nous explorerons un ensemble de données et testerons différents modèles tels que la régression logistique, les réseaux de neurones récurrents et les transformateurs, en examinant leur précision, comment ils vont améliorer le produit, leur rapidité de fonctionnement et s'ils sont faciles à déboguer. et passer à l'échelle.


Vous pouvez lire le code complet de l'étude de cas sur GitHub et consulter le bloc-notes d'analyse avec des graphiques interactifs dans Jupyter Notebook Viewer .


Excité? Allons-y !

Paramétrage des tâches

Imaginez que nous possédons un site Web de commerce électronique. Sur ce site, le vendeur peut télécharger les descriptions des articles qu'il souhaite vendre. Ils doivent également choisir manuellement les catégories d'articles, ce qui peut les ralentir.


Notre tâche est d'automatiser le choix des catégories en fonction de la description de l'article. Cependant, un choix mal automatisé est pire que pas d’automatisation, car une erreur peut passer inaperçue, ce qui peut entraîner des pertes de ventes. Par conséquent, nous pouvons choisir de ne pas définir d’étiquette automatique si nous n’en sommes pas sûrs.


Pour cette étude de cas, nous utiliserons le Ensemble de données textuelles Zenodo E-commerce , contenant des descriptions et des catégories d'éléments.


Bon ou Mauvais? Comment choisir le meilleur modèle

Nous examinerons plusieurs architectures de modèles ci-dessous et c'est toujours une bonne pratique de décider comment choisir la meilleure option avant de commencer. Quel impact ce modèle va-t-il avoir sur notre produit ? …nos infrastructures ?


Évidemment, nous aurons une métrique de qualité technique pour comparer différents modèles hors ligne. Dans ce cas, nous avons une tâche de classification multi-classes, utilisons donc un score de précision équilibré , qui gère bien les étiquettes déséquilibrées.


Bien entendu, l'étape finale typique du test d'un candidat est le test AB - l'étape en ligne, qui donne une meilleure idée de la façon dont les clients sont affectés par le changement. Habituellement, les tests AB prennent plus de temps que les tests hors ligne, c'est pourquoi seuls les meilleurs candidats de l'étape hors ligne sont testés. Il s'agit d'une étude de cas, et nous n'avons pas d'utilisateurs réels, nous n'allons donc pas couvrir les tests AB.


Que devrions-nous prendre en compte avant de faire passer un candidat à l'AB-test ? À quoi pouvons-nous penser pendant la phase hors ligne pour gagner du temps de test en ligne et nous assurer que nous testons réellement la meilleure solution possible ?


Transformer les métriques techniques en métriques orientées impact

La précision équilibrée est excellente, mais ce score ne répond pas à la question « Quel impact exact le modèle va-t-il avoir sur le produit ? ». Pour trouver un score plus orienté produit, nous devons comprendre comment nous allons utiliser le modèle.


Dans notre contexte, faire une erreur est pire que ne pas répondre, car le vendeur devra constater l'erreur et changer de catégorie manuellement. Une erreur inaperçue diminuera les ventes et aggravera l'expérience utilisateur du vendeur, nous risquons de perdre des clients.


Pour éviter cela, nous choisirons des seuils pour le score du modèle afin de ne nous permettre que 1% d'erreurs. La métrique orientée produit peut alors être définie comme suit :


Quel pourcentage d'articles pouvons-nous catégoriser automatiquement si notre tolérance d'erreur n'est que de 1 % ?


Nous ferons référence à cela sous le nom Automatic categorisation percentage ci-dessous lors de la sélection du meilleur modèle. Retrouvez le code complet de sélection du seuil ici .


Temps d'inférence

Combien de temps faut-il à un modèle pour traiter une demande ?


Cela nous permettra approximativement de comparer la quantité de ressources supplémentaires que nous devrons maintenir pour qu'un service puisse gérer la charge de tâche si un modèle est sélectionné plutôt qu'un autre.


Évolutivité

Lorsque notre produit va se développer, dans quelle mesure sera-t-il facile de gérer cette croissance en utilisant une architecture donnée ?


Par croissance, nous pourrions entendre :

  • plus de catégories, une granularité plus élevée des catégories
  • descriptions plus longues
  • des ensembles de données plus grands
  • etc.

Faudra-t-il repenser un choix de modèle pour faire face à la croissance ou une simple reconversion suffira ?


Interprétabilité

Dans quelle mesure sera-t-il facile de déboguer les erreurs du modèle pendant la formation et après le déploiement ?


Taille du modèle

La taille du modèle est importante si :

  • nous voulons que notre modèle soit évalué côté client
  • il est si gros qu'il ne peut pas rentrer dans la RAM


Nous verrons plus tard que les deux éléments ci-dessus ne sont pas pertinents, mais cela vaut quand même la peine d'y réfléchir brièvement.

Exploration et nettoyage des ensembles de données

Avec quoi travaillons-nous ? Examinons les données et voyons si elles doivent être nettoyées !


L'ensemble de données contient 2 colonnes : description de l'élément et catégorie, un total de 50,5 000 lignes.

 file_name = "ecommerceDataset.csv" data = pd.read_csv(file_name, header=None) data.columns = ["category", "description"] print("Rows, cols:", data.shape) # >>> Rows, cols: (50425, 2)


Chaque article se voit attribuer 1 des 4 catégories disponibles : Household , Books , Electronics ou Clothing & Accessories . Voici 1 exemple de description d'article par catégorie :


  • Household SPK Home decor Clay Handmade Wall Hanging Face (Multicolore, H35xW12cm) Rendez votre maison plus belle avec cette tenture murale de masque indien en terre cuite faite à la main, jamais auparavant vous ne pourrez pas attraper cette chose faite à la main sur le marché. Vous pouvez l'ajouter à votre salon/hall d'entrée.


  • Livres BEGF101/FEG1-Cours de base en anglais-1 (Neeraj Publications édition 2018) BEGF101/FEG1-Cours de base en anglais-1


  • Vêtements et accessoires Salopette en denim Broadstar pour femme Obtenez un pass accès illimité en portant une salopette Broadstar. Fabriquée en denim, cette salopette vous gardera à l'aise. Associez-les à un haut de couleur blanche ou noire pour compléter votre look décontracté.


  • Electronics Caprigo Heavy Duty – Support de montage au plafond pour projecteur de qualité supérieure (réglable – Blanc – Capacité de poids 15 kg)


Valeurs manquantes

Il n'y a qu'une seule valeur vide dans l'ensemble de données, que nous allons supprimer.

 print(data.info()) # <class 'pandas.core.frame.DataFrame'> # RangeIndex: 50425 entries, 0 to 50424 # Data columns (total 2 columns): # # Column Non-Null Count Dtype # --- ------ -------------- ----- # 0 category 50425 non-null object # 1 description 50424 non-null object # dtypes: object(2) # memory usage: 788.0+ KB data.dropna(inplace=True)


Doublons

Il existe cependant de nombreuses descriptions en double. Heureusement, tous les doublons appartiennent à une seule catégorie, nous pouvons donc les supprimer en toute sécurité.

 repeated_messages = data \ .groupby("description", as_index=False) \ .agg( n_repeats=("category", "count"), n_unique_categories=("category", lambda x: len(np.unique(x))) ) repeated_messages = repeated_messages[repeated_messages["n_repeats"] > 1] print(f"Count of repeated messages (unique): {repeated_messages.shape[0]}") print(f"Total number: {repeated_messages['n_repeats'].sum()} out of {data.shape[0]}") # >>> Count of repeated messages (unique): 13979 # >>> Total number: 36601 out of 50424


Après avoir supprimé les doublons, il nous reste 55 % de l'ensemble de données d'origine. L'ensemble de données est bien équilibré.

 data.drop_duplicates(inplace=True) print(f"New dataset size: {data.shape}") print(data["category"].value_counts()) # New dataset size: (27802, 2) # Household 10564 # Books 6256 # Clothing & Accessories 5674 # Electronics 5308 # Name: category, dtype: int64


Langue de description

Notez que selon la description de l'ensemble de données,

L'ensemble de données a été extrait de la plateforme de commerce électronique indienne.


Les descriptions ne sont pas nécessairement rédigées en anglais. Certains d'entre eux sont écrits en hindi ou dans d'autres langues en utilisant des symboles non-ASCII ou translittérés en alphabet latin, ou utilisent un mélange de langues. Exemples de la catégorie Books :


  • यू जी सी – नेट जूनियर रिसर्च फैलोशिप एवं सहायक प्रोफेसर योग्यता …
  • Prarambhik Bhartiy Itihas
  • History of NORTH INDIA/வட இந்திய வரலாறு/ …


Pour évaluer la présence de mots non anglais dans les descriptions, calculons 2 scores :


  • Score ASCII : pourcentage de symboles non-ASCII dans une description
  • Score des mots anglais valides : si l’on considère uniquement les lettres latines, quel pourcentage de mots dans la description sont valides en anglais ? Disons que les mots anglais valides sont ceux présents dans Word2Vec-300 formés sur un corpus anglais.


En utilisant le score ASCII, nous apprenons que seulement 2,3 % des descriptions contiennent plus de 1 % de symboles non-ASCII.

 def get_ascii_score(description): total_sym_cnt = 0 ascii_sym_cnt = 0 for sym in description: total_sym_cnt += 1 if sym.isascii(): ascii_sym_cnt += 1 return ascii_sym_cnt / total_sym_cnt data["ascii_score"] = data["description"].apply(get_ascii_score) data[data["ascii_score"] < 0.99].shape[0] / data.shape[0] # >>> 0.023


Le score des mots anglais valides montre que seulement 1,5 % des descriptions contiennent moins de 70 % de mots anglais valides parmi les mots ASCII.

 w2v_eng = gensim.models.KeyedVectors.load_word2vec_format(w2v_path, binary=True) def get_valid_eng_score(description): description = re.sub("[^az \t]+", " ", description.lower()) total_word_cnt = 0 eng_word_cnt = 0 for word in description.split(): total_word_cnt += 1 if word.lower() in w2v_eng: eng_word_cnt += 1 return eng_word_cnt / total_word_cnt data["eng_score"] = data["description"].apply(get_valid_eng_score) data[data["eng_score"] < 0.7].shape[0] / data.shape[0] # >>> 0.015


Ainsi la majorité des descriptions (environ 96 %) sont en anglais ou majoritairement en anglais. Nous pouvons supprimer toutes les autres descriptions, mais laissons-les telles quelles et voyons ensuite comment chaque modèle les gère.

La modélisation

Divisons notre ensemble de données en 3 groupes :

  • Former 70 % - pour former les modèles (19 000 messages)

  • Test 15 % - pour le choix des paramètres et des seuils (4 100 messages)

  • Eval 15% - pour le choix du modèle final (4,1k messages)


 from sklearn.model_selection import train_test_split data_train, data_test = train_test_split(data, test_size=0.3) data_test, data_eval = train_test_split(data_test, test_size=0.5) data_train.shape, data_test.shape, data_eval.shape # >>> ((19461, 3), (4170, 3), (4171, 3))


Modèle de référence : sac de mots + régression logistique

Il est utile de faire quelque chose de simple et trivial au début pour obtenir une bonne base de référence. Comme base de référence, créons une structure de sac de mots basée sur l'ensemble de données du train.


Limitons également la taille du dictionnaire à 100 mots.

 count_vectorizer = CountVectorizer(max_features=100, stop_words="english") x_train_baseline = count_vectorizer.fit_transform(data_train["description"]) y_train_baseline = data_train["category"] x_test_baseline = count_vectorizer.transform(data_test["description"]) y_test_baseline = data_test["category"] x_train_baseline = x_train_baseline.toarray() x_test_baseline = x_test_baseline.toarray()


Je prévois d'utiliser la régression logistique comme modèle, je dois donc normaliser les fonctionnalités du compteur avant l'entraînement.

 ss = StandardScaler() x_train_baseline = ss.fit_transform(x_train_baseline) x_test_baseline = ss.transform(x_test_baseline) lr = LogisticRegression() lr.fit(x_train_baseline, y_train_baseline) balanced_accuracy_score(y_test_baseline, lr.predict(x_test_baseline)) # >>> 0.752


La régression logistique multiclasse a montré une précision équilibrée de 75,2 %. C'est une excellente base de référence !


Bien que la qualité globale de la classification ne soit pas excellente, le modèle peut quand même nous donner quelques indications. Regardons la matrice de confusion, normalisée par le nombre d'étiquettes prédites. L'axe X désigne la catégorie prédite et l'axe Y représente la catégorie réelle. En regardant chaque colonne, nous pouvons voir la distribution des catégories réelles lorsqu'une certaine catégorie a été prédite.


Matrice de confusion pour la solution de base.


Par exemple, Electronics est fréquemment confondue avec Household . Mais même ce modèle simple peut capturer Clothing & Accessories de manière assez précise.


Voici l’importance des fonctionnalités lors de la prévision de la catégorie Clothing & Accessories :

Importance des fonctionnalités pour la solution de base pour l'étiquette « Vêtements et accessoires »


Top 6 des mots les plus favorables et défavorables à la catégorie Clothing & Accessories :

 women 1.49 book -2.03 men 0.93 table -1.47 cotton 0.92 author -1.11 wear 0.69 books -1.10 fit 0.40 led -0.90 stainless 0.36 cable -0.85


RNN

Considérons maintenant des modèles plus avancés, conçus spécifiquement pour fonctionner avec des séquences - réseaux de neurones récurrents . GRU et LSTM sont des couches avancées communes pour lutter contre les gradients explosifs qui se produisent dans les RNN simples.


Nous utiliserons la bibliothèque pytorch pour tokeniser les descriptions, ainsi que pour créer et entraîner un modèle.


Tout d’abord, nous devons transformer les textes en chiffres :

  1. Diviser les descriptions en mots
  2. Attribuez un index à chaque mot du corpus en fonction de l'ensemble de données d'entraînement
  3. Réservez des index spéciaux pour les mots inconnus et le remplissage
  4. Transformez chaque description des ensembles de données d'entraînement et de test en vecteurs d'indices.


Le vocabulaire que nous obtenons en tokenisant simplement l'ensemble de données du train est volumineux - près de 90 000 mots. Plus nous avons de mots, plus l'espace d'intégration que le modèle doit apprendre est grand. Pour simplifier la formation, supprimons-en les mots les plus rares et ne laissons que ceux qui apparaissent dans au moins 3% des descriptions. Cela tronquera le vocabulaire à 340 mots.

(trouvez l'implémentation complète CorpusDictionary ici )


 corpus_dict = util.CorpusDictionary(data_train["description"]) corpus_dict.truncate_dictionary(min_frequency=0.03) data_train["vector"] = corpus_dict.transform(data_train["description"]) data_test["vector"] = corpus_dict.transform(data_test["description"]) print(data_train["vector"].head()) # 28453 [1, 1, 1, 1, 12, 1, 2, 1, 6, 1, 1, 1, 1, 1, 6,... # 48884 [1, 1, 13, 34, 3, 1, 1, 38, 12, 21, 2, 1, 37, ... # 36550 [1, 60, 61, 1, 62, 60, 61, 1, 1, 1, 1, 10, 1, ... # 34999 [1, 34, 1, 1, 75, 60, 61, 1, 1, 72, 1, 1, 67, ... # 19183 [1, 83, 1, 1, 87, 1, 1, 1, 12, 21, 42, 1, 2, 1... # Name: vector, dtype: object


La prochaine chose que nous devons décider est la longueur commune des vecteurs que nous allons alimenter en entrées dans RNN. Nous ne voulons pas utiliser de vecteurs complets, car la description la plus longue contient 9,4 000 jetons.


Cependant, 95 % des descriptions de l'ensemble de données de train ne dépassent pas 352 jetons - c'est une bonne longueur pour le découpage. Que va-t-il se passer avec des descriptions plus courtes ?


Ils vont être rembourrés avec un index de rembourrage jusqu'à la longueur commune.

 print(max(data_train["vector"].apply(len))) # >>> 9388 print(int(np.quantile(data_train["vector"].apply(len), q=0.95))) # >>> 352


Ensuite, nous devons transformer les catégories cibles en vecteurs 0-1 pour calculer la perte et effectuer une rétro-propagation à chaque étape de formation.

 def get_target(label, total_labels=4): target = [0] * total_labels target[label_2_idx.get(label)] = 1 return target data_train["target"] = data_train["category"].apply(get_target) data_test["target"] = data_test["category"].apply(get_target)


Nous sommes maintenant prêts à créer un ensemble de données et un chargeur de données pytorch personnalisés pour alimenter le modèle. Trouvez l’implémentation complète PaddedTextVectorDataset ici .

 ds_train = util.PaddedTextVectorDataset( data_train["description"], data_train["target"], corpus_dict, max_vector_len=352, ) ds_test = util.PaddedTextVectorDataset( data_test["description"], data_test["target"], corpus_dict, max_vector_len=352, ) train_dl = DataLoader(ds_train, batch_size=512, shuffle=True) test_dl = DataLoader(ds_test, batch_size=512, shuffle=False)


Enfin, construisons un modèle.


L'architecture minimale est :

  • couche d'incorporation
  • Couche RNN
  • couche linéaire
  • couche d'activation


En commençant par de petites valeurs de paramètres (taille du vecteur d'intégration, taille d'une couche cachée dans RNN, nombre de couches RNN) et sans régularisation, nous pouvons progressivement rendre le modèle plus compliqué jusqu'à ce qu'il montre de forts signes de surajustement, puis équilibrer régularisation (abandons dans la couche RNN et avant la dernière couche linéaire).


 class GRU(nn.Module): def __init__(self, vocab_size, embedding_dim, n_hidden, n_out): super().__init__() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.n_hidden = n_hidden self.n_out = n_out self.emb = nn.Embedding(self.vocab_size, self.embedding_dim) self.gru = nn.GRU(self.embedding_dim, self.n_hidden) self.dropout = nn.Dropout(0.3) self.out = nn.Linear(self.n_hidden, self.n_out) def forward(self, sequence, lengths): batch_size = sequence.size(1) self.hidden = self._init_hidden(batch_size) embs = self.emb(sequence) embs = pack_padded_sequence(embs, lengths, enforce_sorted=True) gru_out, self.hidden = self.gru(embs, self.hidden) gru_out, lengths = pad_packed_sequence(gru_out) dropout = self.dropout(self.hidden[-1]) output = self.out(dropout) return F.log_softmax(output, dim=-1) def _init_hidden(self, batch_size): return Variable(torch.zeros((1, batch_size, self.n_hidden)))


Nous utiliserons l'optimiseur Adam et cross_entropy comme fonction de perte.


 vocab_size = len(corpus_dict.word_to_idx) emb_dim = 4 n_hidden = 15 n_out = len(label_2_idx) model = GRU(vocab_size, emb_dim, n_hidden, n_out) opt = optim.Adam(model.parameters(), 1e-2) util.fit( model=model, train_dl=train_dl, test_dl=test_dl, loss_fn=F.cross_entropy, opt=opt, epochs=35 ) # >>> Train loss: 0.3783 # >>> Val loss: 0.4730 

Entraîner et tester les pertes par époque, modèle RNN

Ce modèle a montré une précision équilibrée de 84,3 % sur l'ensemble de données d'évaluation. Wow, quel progrès !


Présentation des intégrations pré-entraînées

L'inconvénient majeur de la formation du modèle RNN à partir de zéro est qu'il doit apprendre lui-même la signification des mots - c'est le travail de la couche d'intégration. Des modèles word2vec pré-entraînés sont disponibles pour être utilisés comme couche d'intégration prête à l'emploi, ce qui réduit le nombre de paramètres et ajoute beaucoup plus de signification aux jetons. Utilisons l'un des modèles word2vec disponibles dans pytorch - glove, dim=300 .


Nous n'avons besoin que d'apporter des modifications mineures à la création du jeu de données : nous souhaitons maintenant créer un vecteur d'index glove pour chaque description et l'architecture du modèle.

 ds_emb_train = util.PaddedTextVectorDataset( data_train["description"], data_train["target"], emb=glove, max_vector_len=max_len, ) ds_emb_test = util.PaddedTextVectorDataset( data_test["description"], data_test["target"], emb=glove, max_vector_len=max_len, ) dl_emb_train = DataLoader(ds_emb_train, batch_size=512, shuffle=True) dl_emb_test = DataLoader(ds_emb_test, batch_size=512, shuffle=False)
 import torchtext.vocab as vocab glove = vocab.GloVe(name='6B', dim=300) class LSTMPretrained(nn.Module): def __init__(self, n_hidden, n_out): super().__init__() self.emb = nn.Embedding.from_pretrained(glove.vectors) self.emb.requires_grad_ = False self.embedding_dim = 300 self.n_hidden = n_hidden self.n_out = n_out self.lstm = nn.LSTM(self.embedding_dim, self.n_hidden, num_layers=1) self.dropout = nn.Dropout(0.5) self.out = nn.Linear(self.n_hidden, self.n_out) def forward(self, sequence, lengths): batch_size = sequence.size(1) self.hidden = self.init_hidden(batch_size) embs = self.emb(sequence) embs = pack_padded_sequence(embs, lengths, enforce_sorted=True) lstm_out, (self.hidden, _) = self.lstm(embs) lstm_out, lengths = pad_packed_sequence(lstm_out) dropout = self.dropout(self.hidden[-1]) output = self.out(dropout) return F.log_softmax(output, dim=-1) def init_hidden(self, batch_size): return Variable(torch.zeros((1, batch_size, self.n_hidden)))


Et nous sommes prêts à nous entraîner !

 n_hidden = 50 n_out = len(label_2_idx) emb_model = LSTMPretrained(n_hidden, n_out) opt = optim.Adam(emb_model.parameters(), 1e-2) util.fit(model=emb_model, train_dl=dl_emb_train, test_dl=dl_emb_test, loss_fn=F.cross_entropy, opt=opt, epochs=11) 

Entraîner et tester les pertes par époque, modèle RNN + intégrations pré-entraînées

Nous obtenons désormais une précision équilibrée de 93,7 % sur l'ensemble de données d'évaluation. Courtiser!


BERTE

Les modèles modernes de pointe pour travailler avec des séquences sont les transformateurs. Cependant, pour former un transformateur à partir de zéro, nous aurions besoin d’énormes quantités de données et de ressources informatiques. Ce que nous pouvons essayer ici, c'est d'affiner l'un des modèles pré-entraînés pour répondre à notre objectif. Pour ce faire, nous devons télécharger un modèle BERT pré-entraîné et ajouter un abandon et une couche linéaire pour obtenir la prédiction finale. Il est recommandé de former un modèle réglé pendant 4 époques. Je ne me suis entraîné que 2 époques supplémentaires pour gagner du temps – cela m'a pris 40 minutes pour le faire.


 from transformers import BertModel class BERTModel(nn.Module): def __init__(self, n_out=12): super(BERTModel, self).__init__() self.l1 = BertModel.from_pretrained('bert-base-uncased') self.l2 = nn.Dropout(0.3) self.l3 = nn.Linear(768, n_out) def forward(self, ids, mask, token_type_ids): output_1 = self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids) output_2 = self.l2(output_1.pooler_output) output = self.l3(output_2) return output


 ds_train_bert = bert.get_dataset( list(data_train["description"]), list(data_train["target"]), max_vector_len=64 ) ds_test_bert = bert.get_dataset( list(data_test["description"]), list(data_test["target"]), max_vector_len=64 ) dl_train_bert = DataLoader(ds_train_bert, sampler=RandomSampler(ds_train_bert), batch_size=batch_size) dl_test_bert = DataLoader(ds_test_bert, sampler=SequentialSampler(ds_test_bert), batch_size=batch_size)


 b_model = bert.BERTModel(n_out=4) b_model.to(torch.device("cpu")) def loss_fn(outputs, targets): return torch.nn.BCEWithLogitsLoss()(outputs, targets) optimizer = optim.AdamW(b_model.parameters(), lr=2e-5, eps=1e-8) epochs = 2 scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=total_steps ) bert.fit(b_model, dl_train_bert, dl_test_bert, optimizer, scheduler, loss_fn, device, epochs=epochs) torch.save(b_model, "models/bert_fine_tuned")


Journal d'entraînement :

 2024-02-29 19:38:13.383953 Epoch 1 / 2 Training... 2024-02-29 19:40:39.303002 step 40 / 305 done 2024-02-29 19:43:04.482043 step 80 / 305 done 2024-02-29 19:45:27.767488 step 120 / 305 done 2024-02-29 19:47:53.156420 step 160 / 305 done 2024-02-29 19:50:20.117272 step 200 / 305 done 2024-02-29 19:52:47.988203 step 240 / 305 done 2024-02-29 19:55:16.812437 step 280 / 305 done 2024-02-29 19:56:46.990367 Average training loss: 0.18 2024-02-29 19:56:46.990932 Validating... 2024-02-29 19:57:51.182859 Average validation loss: 0.10 2024-02-29 19:57:51.182948 Epoch 2 / 2 Training... 2024-02-29 20:00:25.110818 step 40 / 305 done 2024-02-29 20:02:56.240693 step 80 / 305 done 2024-02-29 20:05:25.647311 step 120 / 305 done 2024-02-29 20:07:53.668489 step 160 / 305 done 2024-02-29 20:10:33.936778 step 200 / 305 done 2024-02-29 20:13:03.217450 step 240 / 305 done 2024-02-29 20:15:28.384958 step 280 / 305 done 2024-02-29 20:16:57.004078 Average training loss: 0.08 2024-02-29 20:16:57.004657 Validating... 2024-02-29 20:18:01.546235 Average validation loss: 0.09


Enfin, le modèle BERT affiné montre une précision équilibrée de 95,1 % sur l'ensemble de données d'évaluation.


Choisir notre gagnant

Nous avons déjà établi une liste de considérations à prendre en compte pour faire un choix final éclairé.

Voici des graphiques montrant les paramètres mesurables :

Mesures de performances des modèles


Bien que le BERT affiné soit leader en termes de qualité, le RNN avec la couche d'intégration pré-entraînée LSTM+EMB arrive juste derrière, n'étant en retard que de 3 % des attributions automatiques de catégories.


D'un autre côté, le temps d'inférence du BERT affiné est 14 fois plus long que celui LSTM+EMB . Cela entraînera des coûts de maintenance back-end qui dépasseront probablement les avantages apportés BERT ajusté par rapport à LSTM+EMB .


En ce qui concerne l'interopérabilité, notre modèle de régression logistique de base est de loin le plus interprétable et tout réseau neuronal y perd à cet égard. Dans le même temps, la ligne de base est probablement la moins évolutive : l’ajout de catégories diminuera la qualité déjà faible de la ligne de base.


Même si BERT semble être la superstar avec sa grande précision, nous finissons par opter pour le RNN avec une couche d'intégration pré-entraînée. Pourquoi? C'est assez précis, pas trop lent et ne devient pas trop compliqué à gérer lorsque les choses deviennent importantes.


J'espère que vous avez apprécié cette étude de cas. Quel modèle auriez-vous choisi et pourquoi ?