paint-brush
Neugierig auf schnellere ML-Modelle? Entdecken Sie die Modellquantisierung mit PyTorch!von@chinmayjog
416 Lesungen
416 Lesungen

Neugierig auf schnellere ML-Modelle? Entdecken Sie die Modellquantisierung mit PyTorch!

von Chinmay Jog12m2024/02/08
Read on Terminal Reader

Zu lang; Lesen

Erfahren Sie, wie Quantisierung mit Pytorch dazu beitragen kann, dass Ihre trainierten Modelle etwa viermal schneller laufen und gleichzeitig die Genauigkeit erhalten bleibt.
featured image - Neugierig auf schnellere ML-Modelle? Entdecken Sie die Modellquantisierung mit PyTorch!
Chinmay Jog HackerNoon profile picture
0-item
1-item


Wussten Sie, dass in der Welt des maschinellen Lernens die Effizienz von Deep-Learning-Modellen (DL) durch eine Technik namens Quantisierung deutlich gesteigert werden kann? Stellen Sie sich vor, Sie reduzieren den Rechenaufwand Ihres neuronalen Netzwerks, ohne dessen Leistung zu beeinträchtigen. Genau wie das Komprimieren einer großen Datei, ohne deren Wesentliches zu verlieren, ermöglicht Ihnen die Modellquantisierung, Ihre Modelle kleiner und schneller zu machen. Lassen Sie uns in das faszinierende Konzept der Quantisierung eintauchen und die Geheimnisse der Optimierung Ihrer neuronalen Netze für den Einsatz in der Praxis enthüllen.


Bevor wir näher darauf eingehen, sollten die Leser mit neuronalen Netzen und dem Grundkonzept der Quantisierung vertraut sein, einschließlich der Begriffe Skala (S) und Nullpunkt (ZP). Für Leser, die eine Auffrischung wünschen, werden in diesem Artikel und in diesem Artikel das allgemeine Konzept und die Arten der Quantisierung erläutert.


In diesem Leitfaden werde ich kurz erklären, warum Quantisierung wichtig ist und wie man sie mit Pytorch implementiert. Ich werde mich hauptsächlich auf die Art der Quantisierung konzentrieren, die als „statische Quantisierung nach dem Training“ bezeichnet wird und zu einem viermal geringeren Speicherbedarf des ML-Modells führt und die Inferenz bis zu viermal schneller macht.

Konzepte

Warum ist Quantisierung wichtig?

Berechnungen neuronaler Netze werden am häufigsten mit 32-Bit-Gleitkommazahlen durchgeführt. Eine einzelne 32-Bit-Gleitkommazahl (FP32) benötigt 4 Byte Speicher. Im Vergleich dazu benötigt eine einzelne 8-Bit-Ganzzahl (INT8) nur 1 Byte Speicher. Darüber hinaus verarbeiten Computer Ganzzahl-Arithmetik viel schneller als Float-Operationen. Sie können sofort erkennen, dass die Quantisierung eines ML-Modells von FP32 auf INT8 zu viermal weniger Speicher führt. Darüber hinaus wird die Schlussfolgerung um das Vierfache beschleunigt! Da große Modelle derzeit in aller Munde sind, ist es für Praktiker wichtig, trainierte Modelle hinsichtlich Speicher und Geschwindigkeit für Echtzeit-Inferenzen optimieren zu können.


Quelle: Tenor.com


Schlüsselbegriffe

  • Gewichte – Gewichte des trainierten neuronalen Netzwerks.


  • Aktivierungen – In Bezug auf die Quantisierung sind Aktivierungen nicht die Aktivierungsfunktionen wie Sigmoid oder ReLU. Mit Aktivierungen meine ich die Feature-Map-Ausgaben der Zwischenebenen, die Eingaben für die nächsten Ebenen sind.


Statische Quantisierung nach dem Training

Die statische Quantisierung nach dem Training bedeutet, dass wir das Modell nach dem Training des ursprünglichen Modells nicht für die Quantisierung trainieren oder verfeinern müssen. Wir müssen auch keine Zwischenschichteingaben, sogenannte Aktivierungen im laufenden Betrieb, quantisieren. Bei diesem Quantisierungsmodus werden die Gewichte direkt quantisiert, indem die Skala und der Nullpunkt für jede Schicht berechnet werden. Bei Aktivierungen ändern sich jedoch auch die Aktivierungen, wenn sich die Eingabe in das Modell ändert. Wir kennen nicht den Bereich jeder einzelnen Eingabe, auf die das Modell während der Inferenz stößt. Wie können wir also die Skala und den Nullpunkt für alle Aktivierungen des Netzwerks berechnen?


Wir können dies erreichen, indem wir das Modell mithilfe eines guten repräsentativen Datensatzes kalibrieren. Anschließend beobachten wir den Wertebereich der Aktivierungen für den Kalibrierungssatz und verwenden diese Statistiken dann zur Berechnung der Skalierung und des Nullpunkts. Dies geschieht durch das Einfügen von Beobachtern in das Modell, die während der Kalibrierung Datenstatistiken sammeln. Nach der Vorbereitung des Modells (Einfügen von Beobachtern) führen wir den Vorwärtsdurchlauf des Modells für den Kalibrierungsdatensatz durch. Die Beobachter verwenden diese Kalibrierungsdaten, um Maßstab und Nullpunkt für Aktivierungen zu berechnen. Bei der Schlussfolgerung geht es nun nur noch darum, die lineare Transformation auf alle Ebenen mit ihrer jeweiligen Skalierung und ihren Nullpunkten anzuwenden.

Während die gesamte Inferenz in INT8 erfolgt, wird die endgültige Modellausgabe dequantisiert (von INT8 zu FP32).


Warum müssen Aktivierungen quantisiert werden, wenn die Eingabe- und Netzwerkgewichte bereits quantisiert sind?

Das ist eine ausgezeichnete Frage. Während die Netzwerkeingabe und die Gewichte tatsächlich bereits INT8-Werte sind, wird die Ausgabe der Schicht als INT32 gespeichert, um einen Überlauf zu vermeiden. Um die Komplexität bei der Verarbeitung der nächsten Schicht zu reduzieren, werden die Aktivierungen von INT32 bis INT8 quantisiert.


Wenn die Konzepte klar sind, tauchen wir in den Code ein und sehen, wie er funktioniert!


Für dieses Beispiel verwende ich ein resnet18-Modell, das auf den Flowers102-Datensatz abgestimmt ist und direkt in Pytorch verfügbar ist. Der Code funktioniert jedoch für jedes trainierte CNN mit dem entsprechenden Kalibrierungsdatensatz. Da sich dieses Tutorial auf die Quantisierung konzentriert, werde ich nicht auf den Schulungs- und Feinabstimmungsteil eingehen. Der gesamte Code ist jedoch hier zu finden. Lasst uns eintauchen!


Quantisierungscode

Lassen Sie uns die für die Quantisierung erforderlichen Bibliotheken importieren und das fein abgestimmte Modell laden.

 import torch import torchvision import torchvision.transforms as transforms from torchvision.models import resnet18 import torch.nn as nn from torch.ao.quantization import get_default_qconfig from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx from torch.ao.quantization import QConfigMapping import warnings warnings.filterwarnings('ignore')


Als nächstes definieren wir einige Parameter, definieren Datentransformationen und Datenlader und laden das fein abgestimmte Modell

 model_path = 'flowers_model.pth' quantized_model_save_path = 'quantized_flowers_model.pth' batch_size = 10 num_classes = 102 # Define data transforms transform = transforms.Compose( [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( (0.485, 0.465, 0.406), (0.229, 0.224, 0.225))] ) # Define train data loader, for using as calibration set trainset = torchvision.datasets.Flowers102(root='./data', split="train", download=True, transform=transform) trainLoader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) # Load the finetuned resnet model model_to_quantize = resnet18(weights=None) num_features = model_to_quantize.fc.in_features model_to_quantize.fc = nn.Linear(num_features, num_classes) model_to_quantize.load_state_dict(torch.load(model_path)) model_to_quantize.eval() print('Loaded fine-tuned model')

Für dieses Beispiel verwende ich einige Trainingsbeispiele als Kalibrierungssatz.

Definieren wir nun die Konfiguration, die zur Quantisierung des Modells verwendet wird.

 # Define quantization parameters config for the correct platform, # "x86" for x86 devices or "qnnpack" for arm devices qconfig = get_default_qconfig("x86") qconfig_mapping = QConfigMapping().set_global(qconfig)

Im obigen Snippet habe ich die Standardkonfiguration verwendet, aber die QConfig- Klasse von Pytorch wird verwendet, um zu beschreiben, wie das Modell oder ein Teil des Modells quantisiert werden soll. Wir können dies tun, indem wir die Art der Beobachterklassen angeben, die für Gewichtungen und Aktivierungen verwendet werden sollen.


Jetzt können wir das Modell für die Quantisierung vorbereiten

 # Fuse conv-> relu, conv -> bn -> relu layer blocks and insert observers model_prep = prepare_fx(model=model_to_quantize, qconfig_mapping=qconfig_mapping, example_inputs=torch.randn((1,3,224,224)))

Die prepare_fx Funktion fügt die Beobachter in das Modell ein und verschmilzt außerdem die Module conv→relu und conv→bn→relu. Dies führt zu weniger Operationen und einer geringeren Speicherbandbreite, da keine Zwischenergebnisse dieser Module gespeichert werden müssen.


Kalibrieren Sie das Modell, indem Sie die Kalibrierungsdaten vorwärts weiterleiten

 # Run calibration for 10 batches (100 random samples in total) print('Running calibration') with torch.no_grad(): for i, data in enumerate(trainLoader): samples, labels = data _ = model_prep(samples) if i == 10: break

Wir müssen nicht den gesamten Trainingssatz kalibrieren! In diesem Beispiel verwende ich 100 Zufallsstichproben, aber in der Praxis sollten Sie einen Datensatz wählen, der repräsentativ für das ist, was das Modell während der Bereitstellung sehen wird.


Quantisieren Sie das Modell und speichern Sie die quantisierten Gewichte!

 # Quantize calibrated model quantized_model = convert_fx(model_prep) print('Quantized model!') # Save quantized torch.save(quantized_model.state_dict(), quantized_model_save_path) print('Saved quantized model weights to disk')

Und das ist es! Sehen wir uns nun an, wie ein quantisiertes Modell geladen wird, und vergleichen wir dann die Genauigkeit, Geschwindigkeit und den Speicherbedarf des ursprünglichen und des quantisierten Modells.


Laden Sie ein quantisiertes Modell

Ein quantisiertes Modelldiagramm ist nicht ganz dasselbe wie das Originalmodell, auch wenn beide die gleichen Schichten haben.

Das Drucken der ersten Ebene ( conv1 ) beider Modelle zeigt den Unterschied.

 print('\nPrinting conv1 layer of fp32 and quantized model') print(f'fp32 model: {model_to_quantize.conv1}') print(f'quantized model: {quantized_model.conv1}') 

1. Schicht des fp32-Modells und des quantisierten Modells


Sie werden feststellen, dass die Ebene „conv1“ des quantisierten Modells neben der anderen Klasse auch die Parameter „Skalierung“ und „Nullpunkt“ enthält.


Wir müssen also den Quantisierungsprozess (ohne Kalibrierung) verfolgen, um den Modellgraphen zu erstellen, und dann die quantisierten Gewichte laden. Wenn wir das quantisierte Modell im ONNX-Format speichern, können wir es natürlich wie jedes andere ONNX-Modell laden, ohne jedes Mal die Quantisierungsfunktionen ausführen zu müssen.

In der Zwischenzeit definieren wir eine Funktion zum Laden des quantisierten Modells und speichern es in inference_utils.py .

 import torch from torch.ao.quantization import get_default_qconfig from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx from torch.ao.quantization import QConfigMapping def load_quantized_model(model_to_quantize, weights_path): ''' Model only needs to be calibrated for the first time. Next time onwards, to load the quantized model, you still need to prepare and convert the model without calibrating it. After that, load the state dict as usual. ''' model_to_quantize.eval() qconfig = get_default_qconfig("x86") qconfig_mapping = QConfigMapping().set_global(qconfig) model_prep = prepare_fx(model_to_quantize, qconfig_mapping, torch.randn((1,3,224,224))) quantized_model = convert_fx(model_prep) quantized_model.load_state_dict(torch.load(weights_path)) return quantized_model


Definieren Sie Funktionen zur Messung von Genauigkeit und Geschwindigkeit

Genauigkeit messen

 import torch def test_accuracy(model, testLoader): model.eval() running_acc = 0 num_samples = 0 with torch.no_grad(): for i, data in enumerate(testLoader): samples, labels = data outputs = model(samples) preds = torch.argmax(outputs, 1) running_acc += torch.sum(preds == labels) num_samples += samples.size(0) return running_acc / num_samples

Dies ist ein ziemlich einfacher Pytorch-Code.


Inferenzgeschwindigkeit in Millisekunden (ms) messen

 import torch from time import time def test_speed(model): dummy_sample = torch.randn((1,3,224,224)) # Average out inference speed over multiple iterations # to get a true estimate num_iterations = 100 start = time() for _ in range(num_iterations): _ = model(dummy_sample) end = time() return (end-start)/num_iterations * 1000


Fügen Sie diese beiden Funktionen in inference_utils.py hinzu. Wir sind jetzt bereit, Modelle zu vergleichen. Lassen Sie uns den Code durchgehen.


Vergleichen Sie Modelle hinsichtlich Genauigkeit, Geschwindigkeit und Größe

Lassen Sie uns zunächst die erforderlichen Bibliotheken importieren, Parameter, Datentransformationen und den Testdatenlader definieren.

 import os import torch import torch.nn as nn import torchvision from torchvision.models import resnet18 import torchvision.transforms as transforms from inference_utils import test_accuracy, test_speed, load_quantized_model import copy import warnings warnings.filterwarnings('ignore') model_weights_path = 'flowers_model.pth' quantized_model_weights_path = 'quantized_flowers_model.pth' batch_size = 10 num_classes = 102 # Define data transforms transform = transforms.Compose( [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( (0.485, 0.465, 0.406), (0.229, 0.224, 0.225))] ) testset = torchvision.datasets.Flowers102(root='./data', split="test", download=True, transform=transform) testLoader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


Laden Sie die beiden Modelle

 # Load the finetuned resnet model and the quantized model model = resnet18(weights=None) num_features = model.fc.in_features model.fc = nn.Linear(num_features, num_classes) model.load_state_dict(torch.load(model_weights_path)) model.eval() model_to_quantize = copy.deepcopy(model) quantized_model = load_quantized_model(model_to_quantize, quantized_model_weights_path)


Vergleichen Sie Modelle

 # Compare accuracy fp32_accuracy = test_accuracy(model, testLoader) accuracy = test_accuracy(quantized_model, testLoader) print(f'Original model accuracy: {fp32_accuracy:.3f}') print(f'Quantized model accuracy: {accuracy:.3f}\n') # Compare speed fp32_speed = test_speed(model) quantized_speed = test_speed(quantized_model) print(f'Inference time for original model: {fp32_speed:.3f} ms') print(f'Inference time for quantized model: {quantized_speed:.3f} ms\n') # Compare file size fp32_size = os.path.getsize(model_weights_path)/10**6 quantized_size = os.path.getsize(quantized_model_weights_path)/10**6 print(f'Original model file size: {fp32_size:.3f} MB') print(f'Quantized model file size: {quantized_size:.3f} MB')


Ergebnisse

Vergleich von fp32 und quantisiertem Modell


Wie Sie sehen, ist die Genauigkeit des quantisierten Modells auf den Testdaten fast so hoch wie die Genauigkeit des Originalmodells! Die Inferenz mit dem quantisierten Modell ist ~3,6x schneller (!) und das quantisierte Modell benötigt ~4x weniger Speicher als das Originalmodell!


Abschluss

In diesem Artikel haben wir das umfassende Konzept der Quantisierung von ML-Modellen und eine Art der Quantisierung namens Post Training Static Quantization verstanden. Wir haben auch untersucht, warum Quantisierung im Zeitalter großer Modelle wichtig und ein leistungsstarkes Werkzeug ist. Schließlich gingen wir Beispielcode durch, um ein trainiertes Modell mit Pytorch zu quantisieren, und überprüften die Ergebnisse. Wie die Ergebnisse zeigten, hatte die Quantisierung des Originalmodells keinen Einfluss auf die Leistung und verringerte gleichzeitig die Inferenzgeschwindigkeit um das etwa 3,6-fache und den Speicherbedarf um das etwa 4-fache!


Ein paar Punkte sind zu beachten: Statische Quantisierung funktioniert gut für CNNs, dynamische Quantisierung ist jedoch die bevorzugte Methode für Sequenzmodelle. Wenn sich die Quantisierung außerdem drastisch auf die Modellleistung auswirkt, kann die Genauigkeit durch eine Technik namens Quantization Aware Training (QAT) wiederhergestellt werden.


Wie funktionieren dynamische Quantisierung und QAT? Das sind Beiträge für ein anderes Mal. Ich hoffe, dass Ihnen dieser Leitfaden das nötige Wissen vermittelt, um eine statische Quantisierung an Ihren eigenen Pytorch-Modellen durchzuführen.


Verweise