PyTorch 的 DataLoader ( torch.utils.data.Dataloader
) 已经是一个有用的工具,可以高效地加载和预处理用于训练深度学习模型的数据。默认情况下,PyTorch 使用单工作进程( num_workers=0
),但用户可以指定更高的数字来利用并行性并加快数据加载速度。
然而,由于它是一个通用的数据加载器,即使它提供了并行化,它仍然不适合某些自定义用例。在这篇文章中,我们探讨了如何使用torch.multiprocessing()
加速从 3D 医学扫描数据集中加载多个 2D 切片。
torch.utils.data.Dataset
我设想一个用例,其中给出了一组患者的 3D 扫描(即 P1、P2、P3、…)和相应切片的列表;我们的目标是构建一个在每次迭代中输出切片的数据加载器。检查下面的Python 代码,我们构建了一个名为myDataset
的 torch 数据集,并将其传递给torch.utils.data.Dataloader()
。
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
我们用例的主要问题是3D 医学扫描尺寸很大(这里由time.sleep()
操作模拟),因此
从磁盘读取它们可能非常耗时
并且大多数情况下,大量的 3D 扫描数据集无法预先读入内存
理想情况下,我们应该只读取一次与患者相关的所有切片的扫描结果。但由于数据由torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
根据批大小拆分到 worker 中,因此不同的 worker 可能会读取同一个患者两次(请查看下面的图片和日志)。
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
总而言之,以下是torch.utils.data.Dataloader
现有实现中存在的问题
myDataset()
的副本(参考:
patientSliceList
(见下图),因此(patientId,sliceId)组合之间不可能进行自然的改组。(注意:可以改组,但这涉及将输出存储在内存中)
注意:也可以只返回来自每个患者 3D 扫描的一组切片。但如果我们还希望返回依赖于切片的 3D 数组(例如,交互式细化网络( 参见本作品的图 1 ),那么这会大大增加数据加载器的内存占用。
torch.multiprocessing
为了防止多次读取患者扫描结果,理想情况下我们需要每个患者(假设是 8 个患者)由特定的工作人员读取。
为了实现这一点,我们使用与 torch dataloader 类相同的内部工具(即torch.multiprocessing()
),但略有不同。查看下面的工作流程图和 代码,了解我们的自定义数据加载器 - myDataloader
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
上面的代码片段(改为 8 名患者)包含以下函数
__iter__()
- 由于myDataloader()
是一个循环,所以这是它实际循环的函数。
_initWorkers()
- 在这里,我们创建工作进程及其各自的输入队列workerInputQueues[workerId]
。 初始化类时会调用此方法。
fillInputQueues()
- 当我们开始循环时(基本上在每个 epoch 开始时)会调用此函数。它会填充单个 worker 的输入队列。
getSlice()
- 这是返回患者体切片的主要逻辑函数。请在此处查看代码。
collate_tensor_fn()
- 此函数直接从 torch repo - torchv1.12.0复制而来,用于将数据批量处理在一起。为了测试我们的数据加载器是否比默认选项提供了加速,我们使用不同的工作器数量测试每个数据加载器循环的速度。我们在实验中改变了两个参数:
我们首先用我们的玩具数据集进行实验,发现我们的数据加载器执行速度更快。参见下图(或使用此代码重现)
在这里,我们可以看到以下内容
patientSlicesList
变量包含每个患者的 5 个切片。因此,工作人员需要等待读取第二个患者,才能将其添加到批处理的最后一个索引中。然后,我们对一个真实的数据集进行基准测试,其中加载了 3D 扫描,提取了切片,
我们观察到
我们还监控了数据加载过程中使用不同数量的工作器时的资源利用率。随着工作器数量的增加,我们观察到 CPU 和内存使用率增加,这是由于额外进程引入了并行性而导致的。用户在选择最佳工作器数量时应考虑其硬件限制和资源可用性。
在这篇博文中,我们探讨了 PyTorch 的标准 DataLoader 在处理包含大型 3D 医学扫描的数据集时的局限性,并提出了一种使用torch.multiprocessing
的自定义解决方案来提高数据加载效率。
在从这些 3D 医学扫描中提取切片时,默认的 dataLoader 可能会导致对同一患者扫描进行多次读取,因为工作器不共享内存。这种冗余会导致严重的延迟,尤其是在处理大型数据集时。
我们的自定义 dataLoader 将患者分派给各个工作人员,确保每个工作人员只读取一次 3D 扫描。这种方法可防止重复磁盘读取,并利用并行处理来加快数据加载速度。
性能测试表明,我们的自定义 dataLoader 通常优于标准 dataLoader,特别是在批次大小较小和有多个工作进程的情况下。
我们的自定义 dataLoader 通过减少冗余读取和最大化并行性来提高大型 3D 医疗数据集的数据加载效率。这一改进可以缩短训练时间并更好地利用硬件资源。
这篇博客是我和我的同事贾静南共同撰写的。