DataLoader de PyTorch ( torch.utils.data.Dataloader
) ya es una herramienta útil para cargar y preprocesar datos de manera eficiente para entrenar modelos de aprendizaje profundo. De forma predeterminada, PyTorch utiliza un proceso de un solo trabajador ( num_workers=0
), pero los usuarios pueden especificar un número mayor para aprovechar el paralelismo y acelerar la carga de datos.
Sin embargo, dado que es un cargador de datos de uso general y aunque ofrece paralelización, todavía no es adecuado para ciertos casos de uso personalizados. En esta publicación, exploramos cómo podemos acelerar la carga de múltiples cortes 2D de un conjunto de datos de escaneos médicos 3D usando torch.multiprocessing()
.
torch.utils.data.Dataset
Me imagino un caso de uso en el que se proporciona un conjunto de escaneos 3D para pacientes (es decir, P1, P2, P3,…) y una lista de cortes correspondientes; Nuestro objetivo es construir un cargador de datos que genere un segmento en cada iteración . Verifique el código Python a continuación donde creamos un conjunto de datos de antorcha llamado myDataset
y lo pasamos a torch.utils.data.Dataloader()
.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
La principal preocupación con nuestro caso de uso es que los escaneos médicos 3D son de gran tamaño ( emulados aquí por la operación time.sleep()
) y por lo tanto
leerlos desde el disco puede llevar mucho tiempo
y en la mayoría de los casos, un gran conjunto de datos de escaneos 3D no se puede leer previamente en la memoria.
Idealmente, solo deberíamos leer la exploración de cada paciente una vez para todos los cortes asociados a ella. Pero dado que torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
divide los datos en trabajadores según el tamaño del lote, existe la posibilidad de que diferentes trabajadores lean a un paciente dos veces ( verifique la imagen y registre abajo ).
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
Para resumir, estos son los problemas con la implementación existente de torch.utils.data.Dataloader
myDataset()
(Ref:
patientSliceList
( consulte la imagen a continuación ), no es posible una mezcla natural entre combinaciones (patientId, sliceId). ( Nota: se puede mezclar, pero eso implica almacenar las salidas en la memoria )
Nota: También se podrían juntar un montón de cortes del escaneo 3D de cada paciente. Pero si también deseamos devolver matrices 3D dependientes de sectores (por ejemplo, redes de refinamiento interactivas ( consulte la Figura 1 de este trabajo ), esto aumenta en gran medida la huella de memoria de su cargador de datos.
torch.multiprocessing
Para evitar lecturas múltiples de escaneos de pacientes , lo ideal sería que cada paciente ( imaginemos 8 pacientes ) fuera leído por un trabajador en particular.
Para lograr esto, utilizamos las mismas herramientas internas que la clase de carga de datos de torch (es decir, torch.multiprocessing()
), pero con una ligera diferencia. Consulte la figura del flujo de trabajo y el código a continuación para nuestro cargador de datos personalizado: myDataloader
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
El fragmento anterior ( con 8 pacientes en su lugar ) contiene las siguientes funciones
__iter__()
- Dado que myDataloader()
es un bucle, esta es la función sobre la que realmente se repite.
_initWorkers()
: aquí creamos nuestros procesos de trabajo con sus colas de entrada individuales workerInputQueues[workerId]
. Esto se llama cuando se inicializa la clase.
fillInputQueues()
: esta función se llama cuando comenzamos el ciclo ( esencialmente al comienzo de cada época ). Llena la cola de entrada del trabajador individual.
getSlice()
: esta es la función lógica principal que devuelve un segmento de un volumen de paciente. Consulta el código aquí .
collate_tensor_fn()
: esta función se copia directamente desde el repositorio de torch: torchv1.12.0 y se utiliza para agrupar datos.Para probar si nuestro cargador de datos ofrece una aceleración en comparación con la opción predeterminada, probamos la velocidad de cada bucle del cargador de datos utilizando diferentes recuentos de trabajadores . Variamos dos parámetros en nuestros experimentos:
Primero experimentamos con nuestro conjunto de datos de juguetes y vemos que nuestro cargador de datos funciona mucho más rápido. Vea la figura a continuación (o reprodúzcala con este código )
Aquí podemos ver lo siguiente
patientSlicesList
contiene 5 cortes por paciente. Por lo tanto, el trabajador debe esperar a leer el segundo paciente para agregarlo al último índice del lote. Luego comparamos un conjunto de datos real donde se cargan escaneos 3D, se extrae un segmento,
Observamos que
También monitoreamos la utilización de recursos durante la carga de datos con diferentes recuentos de trabajadores. Con una mayor cantidad de trabajadores, observamos un mayor uso de CPU y memoria, lo cual es de esperarse debido al paralelismo introducido por procesos adicionales. Los usuarios deben considerar sus limitaciones de hardware y disponibilidad de recursos al elegir el número óptimo de trabajadores.
En esta publicación de blog, exploramos las limitaciones del DataLoader estándar de PyTorch cuando se trata de conjuntos de datos que contienen grandes escaneos médicos en 3D y presentamos una solución personalizada que utiliza torch.multiprocessing
para mejorar la eficiencia de la carga de datos.
En el contexto de la extracción de cortes de estos escaneos médicos 3D, el cargador de datos predeterminado puede generar múltiples lecturas del mismo escaneo del paciente, ya que los trabajadores no comparten memoria. Esta redundancia provoca retrasos importantes, especialmente cuando se trata de grandes conjuntos de datos.
Nuestro cargador de datos personalizado divide a los pacientes entre trabajadores, asegurando que cada escaneo 3D se lea solo una vez por trabajador. Este enfoque evita lecturas de disco redundantes y aprovecha el procesamiento paralelo para acelerar la carga de datos.
Las pruebas de rendimiento mostraron que nuestro cargador de datos personalizado generalmente supera al cargador de datos estándar, especialmente con lotes más pequeños y múltiples procesos de trabajo.
Nuestro dataLoader personalizado mejora la eficiencia de la carga de datos para grandes conjuntos de datos médicos en 3D al reducir las lecturas redundantes y maximizar el paralelismo. Esta mejora puede conducir a tiempos de capacitación más rápidos y una mejor utilización de los recursos de hardware.
Este blog fue escrito junto con mi colega Jingnan Jia .