Le DataLoader de PyTorch ( torch.utils.data.Dataloader
) est déjà un outil utile pour charger et prétraiter efficacement les données pour la formation de modèles d'apprentissage en profondeur. Par défaut, PyTorch utilise un processus à un seul travailleur ( num_workers=0
), mais les utilisateurs peuvent spécifier un nombre plus élevé pour tirer parti du parallélisme et accélérer le chargement des données.
Cependant, s’agissant d’un chargeur de données à usage général, et même s’il propose la parallélisation, il n’est toujours pas adapté à certains cas d’usage personnalisés. Dans cet article, nous explorons comment accélérer le chargement de plusieurs tranches 2D à partir d'un ensemble de données d'analyses médicales 3D à l'aide torch.multiprocessing()
.
torch.utils.data.Dataset
J'imagine un cas d'utilisation dans lequel, étant donné un ensemble de scans 3D pour des patients (c'est-à-dire P1, P2, P3, …) et une liste de tranches correspondantes ; notre objectif est de créer un chargeur de données qui génère une tranche à chaque itération . Vérifiez le code Python ci-dessous où nous construisons un ensemble de données torch appelé myDataset
et transmettez-le dans torch.utils.data.Dataloader()
.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
La principale préoccupation de notre cas d'utilisation est que les scans médicaux 3D sont de grande taille ( imulés ici par l' opération time.sleep()
) et donc
les lire à partir du disque peut prendre beaucoup de temps
et un grand ensemble de données de numérisations 3D ne peut dans la plupart des cas pas être prélu en mémoire
Idéalement, nous ne devrions lire chaque scan d’un patient qu’une seule fois pour toutes les coupes qui lui sont associées. Mais comme les données sont divisées par torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
en travailleurs en fonction de la taille du lot, il est possible pour différents travailleurs de lire un patient deux fois ( vérifiez l'image et enregistrez-le). ci-dessous ).
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
Pour résumer, voici les problèmes liés à l'implémentation existante de torch.utils.data.Dataloader
myDataset()
(Réf :
patientSliceList
( voir image ci-dessous ), aucun mélange naturel n'est possible entre les combos (patientId, sliceId). ( Remarque : on peut mélanger, mais cela implique de stocker les sorties en mémoire )
Remarque : Il est également possible de renvoyer ensemble un ensemble de tranches issues du scan 3D de chaque patient. Mais si nous souhaitons également renvoyer des tableaux 3D dépendants des tranches (par exemple, des réseaux de raffinement interactifs ( voir Fig1 de ce travail ), alors cela augmente considérablement l'empreinte mémoire de votre chargeur de données.
torch.multiprocessing
Pour éviter les lectures multiples des scans des patients , nous aurions idéalement besoin que chaque patient ( imaginons 8 patients ) soit lu par un travailleur particulier.
Pour y parvenir, nous utilisons les mêmes outils internes que la classe torch dataloader (c'est-à-dire torch.multiprocessing()
) mais avec une légère différence. Vérifiez la figure et le code du flux de travail ci-dessous pour notre chargeur de données personnalisé - myDataloader
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
L'extrait ci-dessus ( avec 8 patients à la place ) contient les fonctions suivantes
__iter__()
- Puisque myDataloader()
est une boucle, c'est la fonction sur laquelle elle boucle.
_initWorkers()
- Ici, nous créons nos processus de travail avec leurs files d'attente d'entrée individuelles workerInputQueues[workerId]
. Ceci est appelé lorsque la classe est initialisée.
fillInputQueues()
- Cette fonction est appelée lorsque nous commençons la boucle ( essentiellement au début de chaque époque ). Il remplit la file d'attente d'entrée de chaque travailleur.
getSlice()
- Il s'agit de la fonction logique principale qui renvoie une tranche à partir d'un volume patient. Vérifiez le code ici .
collate_tensor_fn()
- Cette fonction est directement copiée à partir du dépôt torch - torchv1.12.0 et est utilisée pour regrouper les données par lots.Pour tester si notre chargeur de données offre une accélération par rapport à l'option par défaut, nous testons la vitesse de chaque boucle du chargeur de données en utilisant différents nombres de travailleurs . Nous avons fait varier deux paramètres dans nos expériences :
Nous expérimentons d'abord avec notre ensemble de données de jouets et constatons que notre chargeur de données fonctionne beaucoup plus rapidement. Voir la figure ci-dessous (ou reproduire avec ce code )
Ici, nous pouvons voir ce qui suit
patientSlicesList
contient 5 tranches par patient. Ainsi, le travailleur doit attendre de lire le deuxième patient pour l'ajouter au dernier index du lot. Nous comparons ensuite un ensemble de données réel où des scans 3D sont chargés, une tranche est extraite,
Nous avons observé que
Nous avons également surveillé l'utilisation des ressources pendant le chargement des données avec différents nombres de travailleurs. Avec un nombre de travailleurs plus élevé, nous avons observé une utilisation accrue du processeur et de la mémoire, ce qui est attendu en raison du parallélisme introduit par des processus supplémentaires. Les utilisateurs doivent tenir compte de leurs contraintes matérielles et de la disponibilité des ressources lors du choix du nombre optimal de travailleurs.
Dans cet article de blog, nous avons exploré les limites du DataLoader standard de PyTorch lorsqu'il s'agit d'ensembles de données contenant de grandes analyses médicales 3D et présenté une solution personnalisée utilisant torch.multiprocessing
pour améliorer l'efficacité du chargement des données.
Dans le contexte de l'extraction de tranches à partir de ces scans médicaux 3D, le dataLoader par défaut peut potentiellement conduire à plusieurs lectures du même scan d'un patient, car les travailleurs ne partagent pas de mémoire. Cette redondance entraîne des retards importants, notamment lorsqu'il s'agit de grands ensembles de données.
Notre dataLoader personnalisé répartit les patients entre les travailleurs, garantissant que chaque scan 3D n'est lu qu'une seule fois par travailleur. Cette approche évite les lectures de disque redondantes et exploite le traitement parallèle pour accélérer le chargement des données.
Les tests de performances ont montré que notre dataLoader personnalisé surpasse généralement le dataLoader standard, en particulier avec des lots plus petits et des processus de travail multiples.
Notre dataLoader personnalisé améliore l'efficacité du chargement des données pour les grands ensembles de données médicales 3D en réduisant les lectures redondantes et en maximisant le parallélisme. Cette amélioration peut conduire à des temps de formation plus rapides et à une meilleure utilisation des ressources matérielles.
Ce blog a été écrit avec mon collègue Jingnan Jia .