paint-brush
Um estudo de caso de classificação de texto de aprendizado de máquina com um toque orientado ao produtopor@bemorelavender
29,341 leituras
29,341 leituras

Um estudo de caso de classificação de texto de aprendizado de máquina com um toque orientado ao produto

por Maria K17m2024/03/12
Read on Terminal Reader

Muito longo; Para ler

Este é um estudo de caso de aprendizado de máquina com um toque voltado para o produto: vamos fingir que temos um produto real que precisamos melhorar. Exploraremos um conjunto de dados e testaremos diferentes modelos, como regressão logística, redes neurais recorrentes e transformadores, observando quão precisos eles são, como vão melhorar o produto, quão rápido funcionam e se são fáceis de depurar e aumentar a escala.
featured image - Um estudo de caso de classificação de texto de aprendizado de máquina com um toque orientado ao produto
Maria K HackerNoon profile picture


Vamos fingir que temos um produto real que precisamos melhorar. Exploraremos um conjunto de dados e testaremos diferentes modelos, como regressão logística, redes neurais recorrentes e transformadores, observando quão precisos eles são, como vão melhorar o produto, quão rápido funcionam e se são fáceis de depurar e aumentar a escala.


Você pode ler o código completo do estudo de caso no GitHub e ver o caderno de análise com gráficos interativos no Jupyter Notebook Viewer .


Excitado? Vamos lá!

Configuração de tarefa

Imagine que possuímos um site de comércio eletrônico. Neste site, o vendedor pode fazer upload das descrições dos itens que deseja vender. Eles também precisam escolher as categorias dos itens manualmente, o que pode atrasá-los.


Nossa tarefa é automatizar a escolha das categorias com base na descrição do item. Porém, uma escolha erradamente automatizada é pior do que nenhuma automatização, pois um erro pode passar despercebido, o que pode levar a perdas nas vendas. Portanto, podemos optar por não definir um rótulo automatizado se não tivermos certeza.


Para este estudo de caso, usaremos o Conjunto de dados de texto de comércio eletrônico Zenodo , contendo descrições e categorias de itens.


Bom ou mal? Como escolher o melhor modelo

Consideraremos múltiplas arquiteturas de modelos abaixo e é sempre uma boa prática decidir como escolher a melhor opção antes de começar. Como esse modelo afetará nosso produto? …nossa infraestrutura?


Obviamente, teremos uma métrica de qualidade técnica para comparar vários modelos offline. Nesse caso, temos uma tarefa de classificação multiclasse, então vamos usar uma pontuação de precisão balanceada , que lida bem com rótulos desequilibrados.


É claro que o estágio final típico de teste de um candidato é o teste AB – o estágio on-line, que dá uma imagem melhor de como os clientes são afetados pela mudança. Normalmente, os testes AB consomem mais tempo do que os testes offline, portanto, apenas os melhores candidatos do estágio offline são testados. Este é um estudo de caso e não temos usuários reais, portanto não abordaremos os testes AB.


O que mais devemos considerar antes de avançar um candidato para o teste AB? O que podemos pensar durante a fase off-line para economizar algum tempo de teste on-line e ter certeza de que estamos realmente testando a melhor solução possível?


Transformando métricas técnicas em métricas orientadas para o impacto

A precisão equilibrada é ótima, mas essa pontuação não responde à pergunta “Como exatamente o modelo afetará o produto?”. Para encontrar uma pontuação mais orientada ao produto, devemos entender como usaremos o modelo.


No nosso cenário, errar é pior do que não responder, pois o vendedor terá que perceber o erro e alterar a categoria manualmente. Um erro despercebido diminuirá as vendas e piorará a experiência do usuário do vendedor, corremos o risco de perder clientes.


Para evitar isso, escolheremos limites para a pontuação do modelo de forma que nos permitamos apenas 1% de erros. A métrica orientada ao produto pode então ser definida da seguinte forma:


Que porcentagem de itens podemos categorizar automaticamente se nossa tolerância a erros for de apenas 1%?


Iremos nos referir a isso como Automatic categorisation percentage abaixo ao selecionar o melhor modelo. Encontre o código de seleção de limite completo aqui .


Tempo de inferência

Quanto tempo leva para um modelo processar uma solicitação?


Isso nos permitirá comparar quantos recursos a mais teremos que manter para que um serviço lide com a carga de tarefas se um modelo for selecionado em vez de outro.


Escalabilidade

Quando nosso produto crescer, quão fácil será gerenciar o crescimento usando determinada arquitetura?


Por crescimento podemos significar:

  • mais categorias, maior granularidade de categorias
  • descrições mais longas
  • conjuntos de dados maiores
  • etc.

Teremos que repensar a escolha de um modelo para lidar com o crescimento ou uma simples reconversão será suficiente?


Interpretabilidade

Será fácil depurar erros do modelo durante o treinamento e após a implantação?


Tamanho do modelo

O tamanho do modelo é importante se:

  • queremos que nosso modelo seja avaliado do lado do cliente
  • é tão grande que não cabe na RAM


Veremos mais tarde que ambos os itens acima não são relevantes, mas ainda assim vale a pena considerar brevemente.

Exploração e limpeza de conjunto de dados

Com o que estamos trabalhando? Vamos dar uma olhada nos dados e ver se eles precisam de limpeza!


O conjunto de dados contém 2 colunas: descrição do item e categoria, um total de 50,5 mil linhas.

 file_name = "ecommerceDataset.csv" data = pd.read_csv(file_name, header=None) data.columns = ["category", "description"] print("Rows, cols:", data.shape) # >>> Rows, cols: (50425, 2)


Cada item é atribuído a 1 das 4 categorias disponíveis: Household , Books , Electronics ou Clothing & Accessories . Aqui está um exemplo de descrição de item por categoria:


  • Casa SPK Decoração de casa Argila feita à mão para pendurar na parede (multicor, H35xL12cm) Deixe sua casa mais bonita com esta máscara facial indiana de terracota feita à mão para pendurar na parede, nunca antes você não conseguir pegar essa coisa feita à mão no mercado. Você pode adicionar isso à sua sala de estar/átrio de entrada.


  • Livros BEGF101/FEG1-Curso Básico em Inglês-1 (Edição Neeraj Publications 2018) BEGF101/FEG1-Curso Básico em Inglês-1


  • Roupas e acessórios Macacão jeans feminino Broadstar Ganhe um passe de acesso total usando macacão da Broadstar. Feito em jeans, esse macacão vai deixar você confortável. Combine-os com um top branco ou preto para completar o seu look casual.


  • Electronics Caprigo Heavy Duty - Suporte de suporte de montagem no teto para projetor premium de 2 pés (ajustável - Branco - Capacidade de peso 15 Kgs)


Valores ausentes

Há apenas um valor vazio no conjunto de dados, que iremos remover.

 print(data.info()) # <class 'pandas.core.frame.DataFrame'> # RangeIndex: 50425 entries, 0 to 50424 # Data columns (total 2 columns): # # Column Non-Null Count Dtype # --- ------ -------------- ----- # 0 category 50425 non-null object # 1 description 50424 non-null object # dtypes: object(2) # memory usage: 788.0+ KB data.dropna(inplace=True)


Duplicatas

No entanto, existem muitas descrições duplicadas. Felizmente, todas as duplicatas pertencem a uma categoria, então podemos eliminá-las com segurança.

 repeated_messages = data \ .groupby("description", as_index=False) \ .agg( n_repeats=("category", "count"), n_unique_categories=("category", lambda x: len(np.unique(x))) ) repeated_messages = repeated_messages[repeated_messages["n_repeats"] > 1] print(f"Count of repeated messages (unique): {repeated_messages.shape[0]}") print(f"Total number: {repeated_messages['n_repeats'].sum()} out of {data.shape[0]}") # >>> Count of repeated messages (unique): 13979 # >>> Total number: 36601 out of 50424


Depois de remover as duplicatas, ficamos com 55% do conjunto de dados original. O conjunto de dados é bem balanceado.

 data.drop_duplicates(inplace=True) print(f"New dataset size: {data.shape}") print(data["category"].value_counts()) # New dataset size: (27802, 2) # Household 10564 # Books 6256 # Clothing & Accessories 5674 # Electronics 5308 # Name: category, dtype: int64


Descrição Idioma

Observe que, de acordo com a descrição do conjunto de dados,

O conjunto de dados foi extraído da plataforma de comércio eletrônico indiana.


As descrições não são necessariamente escritas em inglês. Alguns deles são escritos em hindi ou em outros idiomas usando símbolos não ASCII ou transliterados para o alfabeto latino, ou usam uma mistura de idiomas. Exemplos da categoria Books :


  • यू जी सी – नेट जूनियर रिसर्च फैलोशिप एवं सहायक प्रोफेसर योग्यता …
  • Prarambhik Bhartiy Itihas
  • History of NORTH INDIA/வட இந்திய வரலாறு/ …


Para avaliar a presença de palavras não inglesas nas descrições, vamos calcular 2 pontuações:


  • Pontuação ASCII: porcentagem de símbolos não ASCII em uma descrição
  • Pontuação de palavras válidas em inglês: se considerarmos apenas letras latinas, qual a porcentagem de palavras na descrição que são válidas em inglês? Digamos que palavras válidas em inglês sejam aquelas presentes no Word2Vec-300 treinado em um corpus inglês.


Usando a pontuação ASCII, aprendemos que apenas 2,3% das descrições consistem em mais de 1% de símbolos não ASCII.

 def get_ascii_score(description): total_sym_cnt = 0 ascii_sym_cnt = 0 for sym in description: total_sym_cnt += 1 if sym.isascii(): ascii_sym_cnt += 1 return ascii_sym_cnt / total_sym_cnt data["ascii_score"] = data["description"].apply(get_ascii_score) data[data["ascii_score"] < 0.99].shape[0] / data.shape[0] # >>> 0.023


A pontuação de palavras válidas em inglês mostra que apenas 1,5% das descrições têm menos de 70% de palavras válidas em inglês entre palavras ASCII.

 w2v_eng = gensim.models.KeyedVectors.load_word2vec_format(w2v_path, binary=True) def get_valid_eng_score(description): description = re.sub("[^az \t]+", " ", description.lower()) total_word_cnt = 0 eng_word_cnt = 0 for word in description.split(): total_word_cnt += 1 if word.lower() in w2v_eng: eng_word_cnt += 1 return eng_word_cnt / total_word_cnt data["eng_score"] = data["description"].apply(get_valid_eng_score) data[data["eng_score"] < 0.7].shape[0] / data.shape[0] # >>> 0.015


Portanto, a maioria das descrições (cerca de 96%) está em inglês ou maioritariamente em inglês. Podemos remover todas as outras descrições, mas em vez disso, vamos deixá-las como estão e ver como cada modelo as trata.

Modelagem

Vamos dividir nosso conjunto de dados em 3 grupos:

  • Treinar 70% - para treinar os modelos (19 mil mensagens)

  • Teste 15% - para escolha de parâmetros e limites (4,1 mil mensagens)

  • Avaliação 15% - para escolha do modelo final (4,1 mil mensagens)


 from sklearn.model_selection import train_test_split data_train, data_test = train_test_split(data, test_size=0.3) data_test, data_eval = train_test_split(data_test, test_size=0.5) data_train.shape, data_test.shape, data_eval.shape # >>> ((19461, 3), (4170, 3), (4171, 3))


Modelo de linha de base: saco de palavras + regressão logística

É útil fazer algo simples e trivial no início para obter uma boa base. Como linha de base, vamos criar uma estrutura de saco de palavras com base no conjunto de dados do trem.


Vamos também limitar o tamanho do dicionário a 100 palavras.

 count_vectorizer = CountVectorizer(max_features=100, stop_words="english") x_train_baseline = count_vectorizer.fit_transform(data_train["description"]) y_train_baseline = data_train["category"] x_test_baseline = count_vectorizer.transform(data_test["description"]) y_test_baseline = data_test["category"] x_train_baseline = x_train_baseline.toarray() x_test_baseline = x_test_baseline.toarray()


Estou planejando usar a regressão logística como modelo, então preciso normalizar os recursos do contador antes do treinamento.

 ss = StandardScaler() x_train_baseline = ss.fit_transform(x_train_baseline) x_test_baseline = ss.transform(x_test_baseline) lr = LogisticRegression() lr.fit(x_train_baseline, y_train_baseline) balanced_accuracy_score(y_test_baseline, lr.predict(x_test_baseline)) # >>> 0.752


A regressão logística multiclasse apresentou acurácia balanceada de 75,2%. Esta é uma ótima base!


Embora a qualidade geral da classificação não seja excelente, o modelo ainda pode nos fornecer alguns insights. Vejamos a matriz de confusão, normalizada pelo número de rótulos previstos. O eixo X denota a categoria prevista e o eixo Y - a categoria real. Olhando para cada coluna podemos ver a distribuição das categorias reais quando uma determinada categoria foi prevista.


Matriz de confusão para solução de linha de base.


Por exemplo, Electronics é frequentemente confundida com Household . Mas mesmo este modelo simples pode capturar Clothing & Accessories com bastante precisão.


Aqui estão as importâncias dos recursos ao prever a categoria Clothing & Accessories :

Importâncias dos recursos para solução de linha de base para o rótulo 'Roupas e Acessórios'


As 6 palavras que mais contribuem a favor e contra a categoria Clothing & Accessories :

 women 1.49 book -2.03 men 0.93 table -1.47 cotton 0.92 author -1.11 wear 0.69 books -1.10 fit 0.40 led -0.90 stainless 0.36 cable -0.85


RNNs

Agora vamos considerar modelos mais avançados projetados especificamente para trabalhar com sequências - redes neurais recorrentes . GRU e LSTM são camadas avançadas comuns para combater a explosão de gradientes que ocorrem em RNNs simples.


Usaremos a biblioteca pytorch para tokenizar descrições e construir e treinar um modelo.


Primeiro, precisamos transformar textos em números:

  1. Divida as descrições em palavras
  2. Atribua um índice a cada palavra do corpus com base no conjunto de dados de treinamento
  3. Reserve índices especiais para palavras desconhecidas e preenchimento
  4. Transforme cada descrição em conjuntos de dados de treinamento e teste em vetores de índices.


O vocabulário que obtemos simplesmente tokenizando o conjunto de dados do trem é grande – quase 90 mil palavras. Quanto mais palavras tivermos, maior será o espaço de incorporação que o modelo terá que aprender. Para simplificar o treinamento, vamos retirar dele as palavras mais raras e deixar apenas aquelas que aparecem em pelo menos 3% das descrições. Isso truncará o vocabulário para 340 palavras.

(encontre a implementação completa CorpusDictionary aqui )


 corpus_dict = util.CorpusDictionary(data_train["description"]) corpus_dict.truncate_dictionary(min_frequency=0.03) data_train["vector"] = corpus_dict.transform(data_train["description"]) data_test["vector"] = corpus_dict.transform(data_test["description"]) print(data_train["vector"].head()) # 28453 [1, 1, 1, 1, 12, 1, 2, 1, 6, 1, 1, 1, 1, 1, 6,... # 48884 [1, 1, 13, 34, 3, 1, 1, 38, 12, 21, 2, 1, 37, ... # 36550 [1, 60, 61, 1, 62, 60, 61, 1, 1, 1, 1, 10, 1, ... # 34999 [1, 34, 1, 1, 75, 60, 61, 1, 1, 72, 1, 1, 67, ... # 19183 [1, 83, 1, 1, 87, 1, 1, 1, 12, 21, 42, 1, 2, 1... # Name: vector, dtype: object


A próxima coisa que precisamos decidir é o comprimento comum dos vetores que alimentaremos como entradas no RNN. Não queremos usar vetores completos, porque a descrição mais longa contém 9,4 mil tokens.


No entanto, 95% das descrições no conjunto de dados do trem não ultrapassam 352 tokens - esse é um bom comprimento para corte. O que acontecerá com descrições mais curtas?


Eles serão preenchidos com índice de preenchimento até o comprimento comum.

 print(max(data_train["vector"].apply(len))) # >>> 9388 print(int(np.quantile(data_train["vector"].apply(len), q=0.95))) # >>> 352


Em seguida - precisamos transformar as categorias alvo em vetores 0-1 para calcular a perda e realizar a retropropagação em cada etapa de treinamento.

 def get_target(label, total_labels=4): target = [0] * total_labels target[label_2_idx.get(label)] = 1 return target data_train["target"] = data_train["category"].apply(get_target) data_test["target"] = data_test["category"].apply(get_target)


Agora estamos prontos para criar um conjunto de dados e um dataloader pytorch personalizados para alimentar o modelo. Encontre a implementação completa PaddedTextVectorDataset aqui .

 ds_train = util.PaddedTextVectorDataset( data_train["description"], data_train["target"], corpus_dict, max_vector_len=352, ) ds_test = util.PaddedTextVectorDataset( data_test["description"], data_test["target"], corpus_dict, max_vector_len=352, ) train_dl = DataLoader(ds_train, batch_size=512, shuffle=True) test_dl = DataLoader(ds_test, batch_size=512, shuffle=False)


Finalmente, vamos construir um modelo.


A arquitetura mínima é:

  • camada de incorporação
  • Camada RNN
  • camada linear
  • camada de ativação


Começando com pequenos valores de parâmetros (tamanho do vetor de incorporação, tamanho de uma camada oculta no RNN, número de camadas RNN) e sem regularização, podemos gradualmente tornar o modelo mais complicado até que ele mostre fortes sinais de sobreajuste, e então equilibrar regularização (quedas na camada RNN e antes da última camada linear).


 class GRU(nn.Module): def __init__(self, vocab_size, embedding_dim, n_hidden, n_out): super().__init__() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.n_hidden = n_hidden self.n_out = n_out self.emb = nn.Embedding(self.vocab_size, self.embedding_dim) self.gru = nn.GRU(self.embedding_dim, self.n_hidden) self.dropout = nn.Dropout(0.3) self.out = nn.Linear(self.n_hidden, self.n_out) def forward(self, sequence, lengths): batch_size = sequence.size(1) self.hidden = self._init_hidden(batch_size) embs = self.emb(sequence) embs = pack_padded_sequence(embs, lengths, enforce_sorted=True) gru_out, self.hidden = self.gru(embs, self.hidden) gru_out, lengths = pad_packed_sequence(gru_out) dropout = self.dropout(self.hidden[-1]) output = self.out(dropout) return F.log_softmax(output, dim=-1) def _init_hidden(self, batch_size): return Variable(torch.zeros((1, batch_size, self.n_hidden)))


Usaremos o otimizador Adam e cross_entropy como uma função de perda.


 vocab_size = len(corpus_dict.word_to_idx) emb_dim = 4 n_hidden = 15 n_out = len(label_2_idx) model = GRU(vocab_size, emb_dim, n_hidden, n_out) opt = optim.Adam(model.parameters(), 1e-2) util.fit( model=model, train_dl=train_dl, test_dl=test_dl, loss_fn=F.cross_entropy, opt=opt, epochs=35 ) # >>> Train loss: 0.3783 # >>> Val loss: 0.4730 

Perdas de treinamento e teste por época, modelo RNN

Este modelo mostrou 84,3% de precisão balanceada no conjunto de dados de avaliação. Uau, que progresso!


Apresentando embeddings pré-treinados

A principal desvantagem de treinar o modelo RNN do zero é que ele precisa aprender o significado das próprias palavras - esse é o trabalho da camada de incorporação. Modelos word2vec pré-treinados estão disponíveis para uso como uma camada de incorporação pronta, o que reduz o número de parâmetros e adiciona muito mais significado aos tokens. Vamos usar um dos modelos word2vec disponíveis em pytorch - glove, dim=300 .


Precisamos apenas fazer pequenas alterações na criação do Dataset - agora queremos criar um vetor de glove pré-definidos para cada descrição e a arquitetura do modelo.

 ds_emb_train = util.PaddedTextVectorDataset( data_train["description"], data_train["target"], emb=glove, max_vector_len=max_len, ) ds_emb_test = util.PaddedTextVectorDataset( data_test["description"], data_test["target"], emb=glove, max_vector_len=max_len, ) dl_emb_train = DataLoader(ds_emb_train, batch_size=512, shuffle=True) dl_emb_test = DataLoader(ds_emb_test, batch_size=512, shuffle=False)
 import torchtext.vocab as vocab glove = vocab.GloVe(name='6B', dim=300) class LSTMPretrained(nn.Module): def __init__(self, n_hidden, n_out): super().__init__() self.emb = nn.Embedding.from_pretrained(glove.vectors) self.emb.requires_grad_ = False self.embedding_dim = 300 self.n_hidden = n_hidden self.n_out = n_out self.lstm = nn.LSTM(self.embedding_dim, self.n_hidden, num_layers=1) self.dropout = nn.Dropout(0.5) self.out = nn.Linear(self.n_hidden, self.n_out) def forward(self, sequence, lengths): batch_size = sequence.size(1) self.hidden = self.init_hidden(batch_size) embs = self.emb(sequence) embs = pack_padded_sequence(embs, lengths, enforce_sorted=True) lstm_out, (self.hidden, _) = self.lstm(embs) lstm_out, lengths = pad_packed_sequence(lstm_out) dropout = self.dropout(self.hidden[-1]) output = self.out(dropout) return F.log_softmax(output, dim=-1) def init_hidden(self, batch_size): return Variable(torch.zeros((1, batch_size, self.n_hidden)))


E estamos prontos para treinar!

 n_hidden = 50 n_out = len(label_2_idx) emb_model = LSTMPretrained(n_hidden, n_out) opt = optim.Adam(emb_model.parameters(), 1e-2) util.fit(model=emb_model, train_dl=dl_emb_train, test_dl=dl_emb_test, loss_fn=F.cross_entropy, opt=opt, epochs=11) 

Perdas de treinamento e teste por época, modelo RNN + embeddings pré-treinados

Agora estamos obtendo 93,7% de precisão balanceada no conjunto de dados de avaliação. Uau!


BERTO

Os modelos modernos de última geração para trabalhar com sequências são os transformadores. No entanto, para treinar um transformador do zero, precisaríamos de grandes quantidades de dados e recursos computacionais. O que podemos tentar aqui é ajustar um dos modelos pré-treinados para servir ao nosso propósito. Para fazer isso, precisamos baixar um modelo BERT pré-treinado e adicionar dropout e camada linear para obter a previsão final. Recomenda-se treinar um modelo ajustado por 4 épocas. Treinei apenas 2 épocas extras para economizar tempo - levei 40 minutos para fazer isso.


 from transformers import BertModel class BERTModel(nn.Module): def __init__(self, n_out=12): super(BERTModel, self).__init__() self.l1 = BertModel.from_pretrained('bert-base-uncased') self.l2 = nn.Dropout(0.3) self.l3 = nn.Linear(768, n_out) def forward(self, ids, mask, token_type_ids): output_1 = self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids) output_2 = self.l2(output_1.pooler_output) output = self.l3(output_2) return output


 ds_train_bert = bert.get_dataset( list(data_train["description"]), list(data_train["target"]), max_vector_len=64 ) ds_test_bert = bert.get_dataset( list(data_test["description"]), list(data_test["target"]), max_vector_len=64 ) dl_train_bert = DataLoader(ds_train_bert, sampler=RandomSampler(ds_train_bert), batch_size=batch_size) dl_test_bert = DataLoader(ds_test_bert, sampler=SequentialSampler(ds_test_bert), batch_size=batch_size)


 b_model = bert.BERTModel(n_out=4) b_model.to(torch.device("cpu")) def loss_fn(outputs, targets): return torch.nn.BCEWithLogitsLoss()(outputs, targets) optimizer = optim.AdamW(b_model.parameters(), lr=2e-5, eps=1e-8) epochs = 2 scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=total_steps ) bert.fit(b_model, dl_train_bert, dl_test_bert, optimizer, scheduler, loss_fn, device, epochs=epochs) torch.save(b_model, "models/bert_fine_tuned")


Registro de treinamento:

 2024-02-29 19:38:13.383953 Epoch 1 / 2 Training... 2024-02-29 19:40:39.303002 step 40 / 305 done 2024-02-29 19:43:04.482043 step 80 / 305 done 2024-02-29 19:45:27.767488 step 120 / 305 done 2024-02-29 19:47:53.156420 step 160 / 305 done 2024-02-29 19:50:20.117272 step 200 / 305 done 2024-02-29 19:52:47.988203 step 240 / 305 done 2024-02-29 19:55:16.812437 step 280 / 305 done 2024-02-29 19:56:46.990367 Average training loss: 0.18 2024-02-29 19:56:46.990932 Validating... 2024-02-29 19:57:51.182859 Average validation loss: 0.10 2024-02-29 19:57:51.182948 Epoch 2 / 2 Training... 2024-02-29 20:00:25.110818 step 40 / 305 done 2024-02-29 20:02:56.240693 step 80 / 305 done 2024-02-29 20:05:25.647311 step 120 / 305 done 2024-02-29 20:07:53.668489 step 160 / 305 done 2024-02-29 20:10:33.936778 step 200 / 305 done 2024-02-29 20:13:03.217450 step 240 / 305 done 2024-02-29 20:15:28.384958 step 280 / 305 done 2024-02-29 20:16:57.004078 Average training loss: 0.08 2024-02-29 20:16:57.004657 Validating... 2024-02-29 20:18:01.546235 Average validation loss: 0.09


Finalmente, o modelo BERT ajustado mostra uma precisão impressionante de 95,1% no conjunto de dados de avaliação.


Escolhendo nosso vencedor

Já estabelecemos uma lista de considerações a serem observadas para fazer uma escolha final bem informada.

Aqui estão gráficos que mostram parâmetros mensuráveis:

Métricas de desempenho dos modelos


Embora o BERT ajustado seja líder em qualidade, o RNN com camada de incorporação pré-treinada LSTM+EMB está em segundo lugar, ficando atrás apenas em 3% das atribuições automáticas de categoria.


Por outro lado, o tempo de inferência do BERT ajustado é 14 vezes maior que LSTM+EMB . Isso aumentará os custos de manutenção de back-end, que provavelmente superarão os benefícios que BERT ajustado traz sobre LSTM+EMB .


Quanto à interoperabilidade, nosso modelo de regressão logística de base é de longe o mais interpretável e qualquer rede neural perde para ele nesse aspecto. Ao mesmo tempo, a linha de base é provavelmente a menos escalonável – adicionar categorias diminuirá a já baixa qualidade da linha de base.


Mesmo que o BERT pareça uma estrela com sua alta precisão, acabamos optando pelo RNN com uma camada de incorporação pré-treinada. Por que? É bastante preciso, não muito lento e não fica muito complicado de manusear quando as coisas ficam grandes.


Espero que você tenha gostado deste estudo de caso. Qual modelo você escolheria e por quê?