paint-brush
Cómo mejorar la paralelización de los cargadores de datos de Torch mediante el multiprocesamiento de Torch.por@pixelperfectionist
467 lecturas
467 lecturas

Cómo mejorar la paralelización de los cargadores de datos de Torch mediante el multiprocesamiento de Torch.

por Prerak Mody13m2024/06/10
Read on Terminal Reader

Demasiado Largo; Para Leer

El cargador de datos de PyTorch es una herramienta para cargar y preprocesar datos de manera eficiente para entrenar modelos de aprendizaje profundo. En esta publicación, exploramos cómo podemos acelerar este proceso usando nuestro cargador de datos personalizado junto con torch.multiprocessing. Experimentamos cargando múltiples cortes 2D a partir de un conjunto de datos de escaneos médicos 3D.
featured image - Cómo mejorar la paralelización de los cargadores de datos de Torch mediante el multiprocesamiento de Torch.
Prerak Mody HackerNoon profile picture
0-item

Introducción

DataLoader de PyTorch ( torch.utils.data.Dataloader ) ya es una herramienta útil para cargar y preprocesar datos de manera eficiente para entrenar modelos de aprendizaje profundo. De forma predeterminada, PyTorch utiliza un proceso de un solo trabajador ( num_workers=0 ), pero los usuarios pueden especificar un número mayor para aprovechar el paralelismo y acelerar la carga de datos.


Sin embargo, dado que es un cargador de datos de uso general y aunque ofrece paralelización, todavía no es adecuado para ciertos casos de uso personalizados. En esta publicación, exploramos cómo podemos acelerar la carga de múltiples cortes 2D de un conjunto de datos de escaneos médicos 3D usando torch.multiprocessing() .


Deseamos extraer un conjunto de cortes del escaneo 3D de cada paciente. Estos pacientes son parte de un gran conjunto de datos.



Nuestro torch.utils.data.Dataset

Me imagino un caso de uso en el que se proporciona un conjunto de escaneos 3D para pacientes (es decir, P1, P2, P3,…) y una lista de cortes correspondientes; Nuestro objetivo es construir un cargador de datos que genere un segmento en cada iteración . Verifique el código Python a continuación donde creamos un conjunto de datos de antorcha llamado myDataset y lo pasamos a torch.utils.data.Dataloader() .


 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])


La principal preocupación con nuestro caso de uso es que los escaneos médicos 3D son de gran tamaño ( emulados aquí por la operación time.sleep() ) y por lo tanto

  • leerlos desde el disco puede llevar mucho tiempo

  • y en la mayoría de los casos, un gran conjunto de datos de escaneos 3D no se puede leer previamente en la memoria.


Idealmente, solo deberíamos leer la exploración de cada paciente una vez para todos los cortes asociados a ella. Pero dado que torch.utils.data.dataloader(myDataset, batch_size=b, workers=n) divide los datos en trabajadores según el tamaño del lote, existe la posibilidad de que diferentes trabajadores lean a un paciente dos veces ( verifique la imagen y registre abajo ).

Torch divide la carga del conjunto de datos en cada trabajador según el tamaño del lote (=3, en este caso). Debido a esto, cada paciente es leído por varios trabajadores.


 - [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]


Para resumir, estos son los problemas con la implementación existente de torch.utils.data.Dataloader

  • A cada uno de los trabajadores se le pasa una copia de myDataset() (Ref: antorcha v1.2. 0 ), y dado que no tienen memoria compartida, se produce una lectura doble del disco del escaneo 3D de un paciente.


  • Además, dado que la antorcha recorre secuencialmente patientSliceList ( consulte la imagen a continuación ), no es posible una mezcla natural entre combinaciones (patientId, sliceId). ( Nota: se puede mezclar, pero eso implica almacenar las salidas en la memoria )


El estándar torch.utils.data.Dataloader() tiene una cola interna que gestiona globalmente cómo se extraen los resultados de los trabajadores. Incluso si un trabajador en particular prepara los datos, no puede generarlos ya que debe respetar esta cola global.



Nota: También se podrían juntar un montón de cortes del escaneo 3D de cada paciente. Pero si también deseamos devolver matrices 3D dependientes de sectores (por ejemplo, redes de refinamiento interactivas ( consulte la Figura 1 de este trabajo ), esto aumenta en gran medida la huella de memoria de su cargador de datos.



Usando torch.multiprocessing

Para evitar lecturas múltiples de escaneos de pacientes , lo ideal sería que cada paciente ( imaginemos 8 pacientes ) fuera leído por un trabajador en particular.

Aquí, cada trabajador se concentra en leer a un (conjunto de) paciente(s).


Para lograr esto, utilizamos las mismas herramientas internas que la clase de carga de datos de torch (es decir, torch.multiprocessing() ), pero con una ligera diferencia. Consulte la figura del flujo de trabajo y el código a continuación para nuestro cargador de datos personalizado: myDataloader

Aquí, la cola de salida (abajo) contiene las salidas de cada trabajador. Cada trabajador recibe información de entrada (la cola de entrada se muestra en la parte superior) solo para un conjunto específico de pacientes. Por lo tanto, esto evita lecturas múltiples del escaneo 3D de un paciente.



 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()


El fragmento anterior ( con 8 pacientes en su lugar ) contiene las siguientes funciones

  • __iter__() - Dado que myDataloader() es un bucle, esta es la función sobre la que realmente se repite.


  • _initWorkers() : aquí creamos nuestros procesos de trabajo con sus colas de entrada individuales workerInputQueues[workerId] . Esto se llama cuando se inicializa la clase.


  • fillInputQueues() : esta función se llama cuando comenzamos el ciclo ( esencialmente al comienzo de cada época ). Llena la cola de entrada del trabajador individual.


  • getSlice() : esta es la función lógica principal que devuelve un segmento de un volumen de paciente. Consulta el código aquí .


  • collate_tensor_fn() : esta función se copia directamente desde el repositorio de torch: torchv1.12.0 y se utiliza para agrupar datos.


Actuación

Para probar si nuestro cargador de datos ofrece una aceleración en comparación con la opción predeterminada, probamos la velocidad de cada bucle del cargador de datos utilizando diferentes recuentos de trabajadores . Variamos dos parámetros en nuestros experimentos:


  • Numero de trabajadores : Probamos procesos de 1, 2, 4 y 8 trabajadores.
  • Tamaño del lote : Evaluamos diferentes tamaños de lote que van del 1 al 8.

Conjunto de datos de juguetes

Primero experimentamos con nuestro conjunto de datos de juguetes y vemos que nuestro cargador de datos funciona mucho más rápido. Vea la figura a continuación (o reprodúzcala con este código )
Un tiempo total más bajo y más iteraciones por segundo significan un mejor cargador de datos.

Aquí podemos ver lo siguiente

  • Cuando se utiliza un solo trabajador, ambos cargadores de datos son iguales.


  • Cuando se utilizan trabajadores adicionales (es decir, 2,4,8), hay una aceleración en ambos cargadores de datos; sin embargo, la aceleración es mucho mayor en nuestro cargador de datos personalizado.


  • Cuando se utiliza un tamaño de lote de 6 (en comparación con 1,2,3,4), hay un pequeño impacto en el rendimiento. Esto se debe a que, en nuestro conjunto de datos de juguetes, la patientSlicesList contiene 5 cortes por paciente. Por lo tanto, el trabajador debe esperar a leer el segundo paciente para agregarlo al último índice del lote.

Conjunto de datos del mundo real

Luego comparamos un conjunto de datos real donde se cargan escaneos 3D, se extrae un segmento, se realiza algún preprocesamiento adicional y luego se devuelven el segmento y otras matrices. Consulte la figura a continuación para ver los resultados.


Observamos que aumentar la cantidad de procesos de trabajadores (y tamaños de lotes) generalmente condujo a una carga de datos más rápida y por lo tanto puede conducir a un entrenamiento más rápido. Para lotes más pequeños (por ejemplo, 1 o 2), duplicar el número de trabajadores dio lugar a aceleraciones mucho mayores. Sin embargo, a medida que aumentaba el tamaño del lote, disminuía la mejora marginal derivada de la incorporación de más trabajadores.

Cuanto mayores sean las iteraciones por segundo, más rápido será el cargador de datos.

Utilización de recursos

También monitoreamos la utilización de recursos durante la carga de datos con diferentes recuentos de trabajadores. Con una mayor cantidad de trabajadores, observamos un mayor uso de CPU y memoria, lo cual es de esperarse debido al paralelismo introducido por procesos adicionales. Los usuarios deben considerar sus limitaciones de hardware y disponibilidad de recursos al elegir el número óptimo de trabajadores.

Resumen

  1. En esta publicación de blog, exploramos las limitaciones del DataLoader estándar de PyTorch cuando se trata de conjuntos de datos que contienen grandes escaneos médicos en 3D y presentamos una solución personalizada que utiliza torch.multiprocessing para mejorar la eficiencia de la carga de datos.


  2. En el contexto de la extracción de cortes de estos escaneos médicos 3D, el cargador de datos predeterminado puede generar múltiples lecturas del mismo escaneo del paciente, ya que los trabajadores no comparten memoria. Esta redundancia provoca retrasos importantes, especialmente cuando se trata de grandes conjuntos de datos.


  3. Nuestro cargador de datos personalizado divide a los pacientes entre trabajadores, asegurando que cada escaneo 3D se lea solo una vez por trabajador. Este enfoque evita lecturas de disco redundantes y aprovecha el procesamiento paralelo para acelerar la carga de datos.


  4. Las pruebas de rendimiento mostraron que nuestro cargador de datos personalizado generalmente supera al cargador de datos estándar, especialmente con lotes más pequeños y múltiples procesos de trabajo.


    1. Sin embargo, las mejoras en el rendimiento disminuyeron con lotes de mayor tamaño.


Nuestro dataLoader personalizado mejora la eficiencia de la carga de datos para grandes conjuntos de datos médicos en 3D al reducir las lecturas redundantes y maximizar el paralelismo. Esta mejora puede conducir a tiempos de capacitación más rápidos y una mejor utilización de los recursos de hardware.


Este blog fue escrito junto con mi colega Jingnan Jia .