Introduction

PyTorch's DataLoader (torch.utils.data.DataLoader) is already a useful tool for efficiently loading and preprocessing data for training deep learning models. By default, PyTorch uses single-process data loading (num_workers=0), but users can specify a higher number of workers to leverage parallelism and speed up data loading. However, since it is a general-purpose dataloader, even with parallelization it is still not suitable for certain custom use cases. In this post, we explore how we can speed up the loading of multiple 2D slices from a dataset of 3D medical scans using torch.multiprocessing.

Our torch.utils.data.Dataset

Imagine a use case in which, given a set of 3D scans of patients (i.e., P1, P2, P3, …) and a list of corresponding slices, our goal is to build a dataloader that outputs a slice in every iteration. Check the Python code below, where we build a torch dataset called myDataset and pass it into torch.utils.data.DataLoader().

# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py
import tqdm
import time
import torch  # v1.12.1
import numpy as np

##################################################
# myDataset
##################################################
def getPatientArray(patientName):
    # return patient's 3D scan
    ...

def getPatientSliceArray(patientName, sliceId, patientArray=None):
    # return patientArray and a slice
    ...

class myDataset(torch.utils.data.Dataset):

    def __init__(self, patientSlicesList, patientsInMemory=1):
        ...
        self.patientObj = {}  # To store one patient's 3D array. More patients lead to more memory usage.

    def _managePatientObj(self, patientName):
        if len(self.patientObj) > self.patientsInMemory:
            self.patientObj.pop(list(self.patientObj.keys())[0])

    def __getitem__(self, idx):

        # Step 0 - Init
        patientName, sliceId = ...

        # Step 1 - Get patient slice array
        patientArrayThis = self.patientObj.get(patientName, None)
        patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis)
        if patientArray is not None:
            self.patientObj[patientName] = patientArray
            self._managePatientObj(patientName)

        return patientSliceArray, [patientName, sliceId]

##################################################
# Main
##################################################
if __name__ == '__main__':

    # Step 1 - Setup patient slices (fixed count of slices per patient)
    patientSlicesList = {
        'P1': [45, 62, 32, 21, 69]
        , 'P2': [13, 23, 87, 54, 5]
        , 'P3': [34, 56, 78, 90, 12]
        , 'P4': [34, 56, 78, 90, 12]
    }
    workerCount, batchSize, epochs = 4, 3, 3

    # Step 2.1 - Create dataset and dataloader
    dataset = myDataset(patientSlicesList)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, num_workers=workerCount)

    # Step 2.2 - Iterate over dataloader
    print ('\n - [main] Iterating over (my) dataloader...')
    for epochId in range(epochs):
        print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs))
        with tqdm.tqdm(total=len(dataset)) as pbar:
            for i, (patientSliceArray, meta) in enumerate(dataloader):
                print (' - [main] meta: ', meta)
                pbar.update(patientSliceArray.shape[0])

The main concern with our use case is that 3D medical scans are large in size (emulated here by the time.sleep() operation), and hence:
- reading them from disk can be time intensive, and
- a large dataset of 3D scans in most cases cannot be pre-read into memory.
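To give a feel for the cost being emulated, here is a minimal sketch of how the two helper functions above could be implemented, with time.sleep() standing in for a slow disk read. This is our illustrative assumption rather than the exact gist code: the array shape, the sleep duration, and the worker-id logging (via torch.utils.data.get_worker_info(), one way to produce the "[worker=N]" tags seen in the log further below) are all placeholders.

import time
import torch
import numpy as np

def getPatientArray(patientName):
    # Emulate an expensive 3D volume read from disk
    workerInfo = torch.utils.data.get_worker_info()  # None when called from the main process
    workerId = workerInfo.id if workerInfo is not None else 0
    print (' - [getPatientArray()][worker={}] Loading volumes for patient: {}'.format(workerId, patientName))
    time.sleep(1.0)  # stand-in for the slow disk read
    return np.random.random((100, 100, 100)).astype(np.float32)

def getPatientSliceArray(patientName, sliceId, patientArray=None):
    # Pay the disk-read cost only if the caller has no cached volume for this patient
    if patientArray is None:
        patientArray = getPatientArray(patientName)
        return patientArray, patientArray[..., sliceId]
    return None, patientArray[..., sliceId]  # cached volume: nothing new for the caller to store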
Ideally, we should only read each patient scan once for all the slices associated with it. But since data is split by torch.utils.data.DataLoader(myDataset, batch_size=b, num_workers=n) across workers depending on the batch size, there is a possibility for different workers to read the same patient's scan twice (check the image and log below).

 - [main] Iterating over (my) dataloader...
 - [main] --------------------------------------- Epoch 1/3
 - [getPatientArray()][worker=3] Loading volumes for patient: P2
 - [getPatientArray()][worker=1] Loading volumes for patient: P1
 - [getPatientArray()][worker=2] Loading volumes for patient: P2
 - [getPatientArray()][worker=0] Loading volumes for patient: P1
 - [getPatientArray()][worker=3] Loading volumes for patient: P3
 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])]
 - [getPatientArray()][worker=1] Loading volumes for patient: P2
 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])]
 - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])]
 - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])]
 - [getPatientArray()][worker=2] Loading volumes for patient: P4
 - [getPatientArray()][worker=0] Loading volumes for patient: P3
 - [getPatientArray()][worker=1] Loading volumes for patient: P4
 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])]
 - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])]
 - [main] meta: [('P4', 'P4'), tensor([90, 12])]

To summarize, here are the issues with the existing implementation of torch.utils.data.DataLoader:
1. Each worker is passed a copy of myDataset() (Ref: torch v1.2.0), and since the workers do not have any shared memory, the same patient's 3D scan ends up being read from disk more than once.
2. Moreover, since the dataloader sequentially loops over patientSlicesList (see image below), no natural shuffling is possible between (patientId, sliceId) combos. (Note: one can shuffle, but that involves storing outputs in memory.)

Note: One could also just return a bunch of slices together from each patient's 3D scan. But if we also wish to return slice-dependent 3D arrays (for example, for interactive refinement networks; see Fig. 1 of this work), this greatly increases the memory footprint of your dataloader. A minimal sketch of this alternative follows below.
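Purely for illustration, such a patient-level dataset could look like the sketch below. The class myPatientLevelDataset is hypothetical (it is not part of the gist) and reuses getPatientArray() from the earlier sketch; the comment marks where the memory footprint grows if slice-dependent 3D arrays are also returned.

import torch

# Hypothetical alternative (not the approach used in this post): one item per patient,
# returning all of that patient's slices together so the volume is read only once.
class myPatientLevelDataset(torch.utils.data.Dataset):

    def __init__(self, patientSlicesList):
        self.patientSlicesList = patientSlicesList
        self.patientNames = list(patientSlicesList.keys())

    def __len__(self):
        return len(self.patientNames)

    def __getitem__(self, idx):
        patientName  = self.patientNames[idx]
        sliceIds     = self.patientSlicesList[patientName]
        patientArray = getPatientArray(patientName)                    # single disk read per patient
        patientSliceArrays = [patientArray[..., s] for s in sliceIds]  # all 2D slices of this patient
        # If each slice also needed its own slice-dependent 3D array (e.g., for interactive
        # refinement networks), every item would carry several full volumes, which is where
        # the memory footprint of the dataloader grows quickly.
        return patientSliceArrays, [patientName, sliceIds]

Since, in general, patients may have a different number of slices, batching these items would also require a custom collate_fn.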
Using torch.multiprocessing

To prevent multiple reads of patient scans, we would ideally need each patient (let's imagine 8 patients) to be read by one particular worker. To achieve this, we use the same internal tools as the torch dataloader class (i.e., torch.multiprocessing), but with a slight difference. Check the workflow figure and code below for our custom dataloader, myDataloader.

# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py
class myDataloader:

    def __init__(self, patientSlicesList, numWorkers, batchSize) -> None:
        ...
        self._initWorkers()

    def _initWorkers(self):

        # Step 1 - Initialize vars
        self.workerProcesses = []
        self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)]
        self.workerOutputQueue = torchMP.Queue()

        for workerId in range(self.numWorkers):
            p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue))
            p.start()
            self.workerProcesses.append(p)

    def fillInputQueues(self):
        """
        This function allows us to split patients and slices across workers.
        One can implement custom logic here.
        """
        patientNames = list(self.patientSlicesList.keys())
        for workerId in range(self.numWorkers):
            idxs = ...
            for patientName in patientNames[idxs]:
                for sliceId in self.patientSlicesList[patientName]:
                    self.workerInputQueues[workerId].put((patientName, sliceId))

    def emptyAllQueues(self):
        # empties the self.workerInputQueues and self.workerOutputQueue
        ...

    def __iter__(self):
        try:
            # Step 0 - Init
            self.fillInputQueues()  # once for each epoch
            batchArray, batchMeta = [], []

            # Step 1 - Continuously yield results
            while True:

                if not self.workerOutputQueue.empty():

                    # Step 2.1 - Get data point
                    patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT)

                    # Step 2.2 - Append to batch
                    ...

                    # Step 2.3 - Yield batch
                    if len(batchArray) == self.batchSize:
                        batchArray = collate_tensor_fn(batchArray)
                        yield batchArray, batchMeta
                        batchArray, batchMeta = [], []

                # Step 3 - End condition
                if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty():
                    break

        except GeneratorExit:
            self.emptyAllQueues()
        except KeyboardInterrupt:
            self.closeProcesses()
        except:
            traceback.print_exc()

    def closeProcesses(self):
        pass

if __name__ == "__main__":

    # Step 1 - Setup patient slices (fixed count of slices per patient)
    patientSlicesList = {
        'P1': [45, 62, 32, 21, 69]
        , 'P2': [13, 23, 87, 54, 5]
        , 'P3': [34, 56, 78, 90, 12]
        , 'P4': [34, 56, 78, 90, 12]
        , 'P5': [45, 62, 32, 21, 69]
        , 'P6': [13, 23, 87, 54, 5]
        , 'P7': [34, 56, 78, 90, 12]
        , 'P8': [34, 56, 78, 90, 12, 21]
    }
    workerCount, batchSize, epochs = 4, 1, 3

    # Step 2 - Create new dataloader
    dataloaderNew = None
    try:
        dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize)
        print ('\n - [main] Iterating over (my) dataloader...')
        for epochId in range(epochs):
            with tqdm.tqdm(total=sum(len(each) for each in patientSlicesList.values()), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar:
                for i, (X, meta) in enumerate(dataloaderNew):
                    print (' - [main] {}'.format(meta.tolist()))
                    pbar.update(X.shape[0])
        dataloaderNew.closeProcesses()
    except KeyboardInterrupt:
        if dataloaderNew is not None:
            dataloaderNew.closeProcesses()
    except:
        traceback.print_exc()
        if dataloaderNew is not None:
            dataloaderNew.closeProcesses()

The snippet above (with 8 patients instead) contains the following functions:
- __iter__(): Since myDataloader() is consumed in a loop, this is the function it actually loops over.
- _initWorkers(): Here, we create our worker processes with their individual input queues workerInputQueues[workerId]. This is called when the class is initialized.
- fillInputQueues(): This function is called when we begin the loop (essentially at the start of every epoch). It fills up each worker's input queue.
- getSlice(): This is the main logic function that returns a slice from a patient volume; check the full version in the gist linked above (a simplified sketch follows below).
- collate_tensor_fn(): This function is copied directly from the torch repo (torch v1.12.0) and is used to batch data together.
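To make the worker side concrete, here is a simplified sketch of what the getSlice() worker loop could look like. The one-volume cache is our assumption of how duplicate reads are avoided, and getPatientArray() is reused from the earlier sketch; the full version (including queue timeouts and clean shutdown) lives in the gist linked above.

# Simplified sketch of the per-worker function passed to torchMP.Process() in _initWorkers()
def getSlice(workerId, inputQueue, outputQueue):
    lastPatientName, lastPatientArray = None, None    # each worker caches its most recent volume
    while True:
        patientName, sliceId = inputQueue.get()       # blocks until work arrives for this worker
        if patientName != lastPatientName:            # new patient -> exactly one disk read
            lastPatientArray = getPatientArray(patientName)
            lastPatientName = patientName
        outputQueue.put((lastPatientArray[..., sliceId], patientName, sliceId))

Because fillInputQueues() pushes all of a patient's slices consecutively into a single worker's input queue, this one-volume cache is enough, under this scheme, for each 3D scan to be read from disk only once per epoch.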
Performance

To test whether our dataloader offers a speedup compared to the default option, we test the speed of each dataloader loop using different worker counts. We varied two parameters in our experiments:
- Number of workers: we tested 1, 2, 4, and 8 worker processes.
- Batch size: we evaluated batch sizes ranging from 1 to 8.

Toy Dataset

We first experiment with our toy dataset and see that our dataloader performs much faster. See the figure below (or reproduce it with this code). Here, we can see the following:
- When using a single worker, both dataloaders are the same.
- When using additional workers (i.e., 2, 4, or 8), there is a speedup in both dataloaders; however, the speedup is much higher in our custom dataloader.
- When using a batch size of 6 (as compared to 1, 2, 3, or 4), there is a small hit in performance. This is because, in our toy dataset, the patientSlicesList variable contains 5 slices per patient, so the worker needs to wait for the second patient to be read before it can fill the last index of the batch.

Real World Dataset

We then benchmark a real dataset where 3D scans are loaded, a slice is extracted, some additional preprocessing is done, and then the slice and other arrays are returned. See the figure below for results.

We observed that increasing the number of worker processes (and batch sizes) generally led to faster data loading and may therefore lead to faster training. For smaller batch sizes (e.g., 1 or 2), doubling the number of workers resulted in much larger speedups. However, as the batch size increased, the marginal improvement from adding more workers diminished.

Resource Utilization

We also monitored resource utilization during data loading with varying worker counts. With a higher number of workers, we observed increased CPU and memory usage, which is expected due to the parallelism introduced by additional processes. Users should consider their hardware constraints and resource availability when choosing the optimal worker count.

Summary

In this blog post, we explored the limitations of PyTorch's standard DataLoader when dealing with datasets containing large 3D medical scans and presented a custom solution using torch.multiprocessing to improve data loading efficiency.

In the context of slice extraction from these 3D medical scans, the default DataLoader can lead to multiple reads of the same patient scan, since the workers do not share memory. This redundancy causes significant delays, particularly when dealing with large datasets.

Our custom dataloader splits patients between workers, ensuring that each 3D scan is read from disk only once, by a single worker. This approach prevents redundant disk reads and leverages parallel processing to speed up data loading.

Performance testing showed that our custom dataloader generally outperforms the standard DataLoader, especially with smaller batch sizes and multiple worker processes, although the performance gains diminished with larger batch sizes.

Overall, our custom dataloader improves data loading efficiency for large 3D medical datasets by reducing redundant reads and maximizing parallelism. This improvement can lead to faster training times and better utilization of hardware resources.

This blog was written together with my colleague Jingnan Jia.