paint-brush
如何使用 Torch.multiprocessing 提高 Torch 数据加载器的并行化经过@pixelperfectionist
439 讀數
439 讀數

如何使用 Torch.multiprocessing 提高 Torch 数据加载器的并行化

经过 Prerak Mody13m2024/06/10
Read on Terminal Reader

太長; 讀書

PyTorch dataloader 是一种用于高效加载和预处理数据以训练深度学习模型的工具。在本文中,我们将探讨如何使用自定义 dataloader 和 torch.multiprocessing 来加速此过程。我们尝试从 3D 医学扫描数据集中加载多个 2D 切片。
featured image - 如何使用 Torch.multiprocessing 提高 Torch 数据加载器的并行化
Prerak Mody HackerNoon profile picture
0-item

介绍

PyTorch 的 DataLoader ( torch.utils.data.Dataloader ) 已经是一个有用的工具,可以高效地加载和预处理用于训练深度学习模型的数据。默认情况下,PyTorch 使用单工作进程( num_workers=0 ),但用户可以指定更高的数字来利用并行性并加快数据加载速度。


然而,由于它是一个通用的数据加载器,即使它提供了并行化,它仍然不适合某些自定义用例。在这篇文章中,我们探讨了如何使用torch.multiprocessing()加速从 3D 医学扫描数据集中加载多个 2D 切片


我们希望从每位患者的 3D 扫描中提取一组切片。这些患者是大型数据集的一部分。



我们的torch.utils.data.Dataset

设想一个用例,其中给出了一组患者的 3D 扫描(即 P1、P2、P3、…)和相应切片的列表;我们的目标是构建一个在每次迭代中输出切片的数据加载器。检查下面的Python 代码,我们构建了一个名为myDataset的 torch 数据集,并将其传递给torch.utils.data.Dataloader()


 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])


我们用例的主要问题是3D 医学扫描尺寸很大(这里由time.sleep()操作模拟),因此

  • 从磁盘读取它们可能非常耗时

  • 并且大多数情况下,大量的 3D 扫描数据集无法预先读入内存


理想情况下,我们应该只读取一次与患者相关的所有切片的扫描结果。但由于数据由torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)根据批大小拆分到 worker 中,因此不同的 worker 可能会读取同一个患者两次(请查看下面的图片和日志)。

Torch 根据批次大小(本例中为 =3)将数据集的加载拆分到每个工作器中。因此,每个患者由多个工作器读取。


 - [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]


总而言之,以下是torch.utils.data.Dataloader现有实现中存在的问题

  • 每位工作人员都会收到一份myDataset()的副本(参考:火炬 v1.2. 0 ),而且由于它们没有任何共享内存,这会导致对患者的 3D 扫描进行双磁盘读取。


  • 此外,由于 torch 顺序循环遍历patientSliceList见下图),因此(patientId,sliceId)组合之间不可能进行自然的改组。(注意:可以改组,但这涉及将输出存储在内存中


标准 torch.utils.data.Dataloader() 有一个内部队列,用于全局管理如何从 worker 中提取输出。即使某个 worker 已准备好数据,它也无法输出数据,因为它必须遵守这个全局队列。



注意:也可以只返回来自每个患者 3D 扫描的一组切片。但如果我们还希望返回依赖于切片的 3D 数组(例如,交互式细化网络( 参见本作品的图 1 ),那么这会大大增加数据加载器的内存占用。



使用torch.multiprocessing

为了防止多次读取患者扫描结果,理想情况下我们需要每个患者(假设是 8 个患者)由特定的工作人员读取。

在这里,每个工作人员都专注于阅读一组(多个)病人。


为了实现这一点,我们使用与 torch dataloader 类相同的内部工具(即torch.multiprocessing() ),但略有不同。查看下面的工作流程图和 代码,了解我们的自定义数据加载器 - myDataloader

这里,输出队列(底部)包含来自每个工作者的输出。每个工作者仅接收一组特定患者的输入信息(顶部显示的输入队列)。因此,这可以防止多次读取患者的 3D 扫描。



 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()


上面的代码片段(改为 8 名患者)包含以下函数

  • __iter__() - 由于myDataloader()是一个循环,所以这是它实际循环的函数。


  • _initWorkers() - 在这里,我们创建工作进程及其各自的输入队列workerInputQueues[workerId] 。 初始化类时会调用此方法。


  • fillInputQueues() - 当我们开始循环时(基本上在每个 epoch 开始时)会调用此函数。它会填充单个 worker 的输入队列。


  • getSlice() - 这是返回患者体切片的主要逻辑函数。请在此处查看代码。


  • collate_tensor_fn() - 此函数直接从 torch repo - torchv1.12.0复制而来,用于将数据批量处理在一起。


表现

为了测试我们的数据加载器是否比默认选项提供了加速,我们使用不同的工作器数量测试每个数据加载器循环的速度。我们在实验中改变了两个参数:


  • 工人数量:我们测试了 1、2、4 和 8 个工作进程。
  • 批次大小:我们评估了从 1 到 8 的不同批次大小。

玩具数据集

我们首先用我们的玩具数据集进行实验,发现我们的数据加载器执行速度更快。参见下图(或使用此代码重现)
更短的总时间和更高的每秒迭代次数意味着更好的数据加载器。

在这里,我们可以看到以下内容

  • 当使用单个工作器时,两个数据加载器是相同的。


  • 当使用额外的工作器(即 2,4,8)时,两个数据加载器的速度都会加快,但是,我们的自定义数据加载器的速度要高得多。


  • 当使用批处理大小为 6(与 1、2、3、4 相比)时,性能会略有下降。这是因为在我们的玩具数据集中, patientSlicesList变量包含每个患者的 5 个切片。因此,工作人员需要等待读取第二个患者,才能将其添加到批处理的最后一个索引中。

真实世界数据集

然后,我们对一个真实的数据集进行基准测试,其中加载了 3D 扫描,提取了切片,进行了一些额外的预处理,然后返回切片和其他数组,结果见下图。


我们观察到增加工作进程数(和批处理大小)通常可以加快数据加载速度因此可能会加快训练速度。对于较小的批次大小(例如 1 或 2),将工人数量增加一倍可带来更大的加速。但是,随着批次大小的增加,增加更多工人所带来的边际改善会减少。

每秒的迭代次数越高,数据加载速度越快。

资源利用率

我们还监控了数据加载过程中使用不同数量的工作器时的资源利用率。随着工作器数量的增加,我们观察到 CPU 和内存使用率增加,这是由于额外进程引入了并行性而导致的。用户在选择最佳工作器数量时应考虑其硬件限制和资源可用性。

概括

  1. 在这篇博文中,我们探讨了 PyTorch 的标准 DataLoader 在处理包含大型 3D 医学扫描的数据集时的局限性,并提出了一种使用torch.multiprocessing的自定义解决方案来提高数据加载效率。


  2. 在从这些 3D 医学扫描中提取切片时,默认的 dataLoader 可能会导致对同一患者扫描进行多次读取,因为工作器不共享内存。这种冗余会导致严重的延迟,尤其是在处理大型数据集时。


  3. 我们的自定义 dataLoader 将患者分派给各个工作人员,确保每个工作人员只读取一次 3D 扫描。这种方法可防止重复磁盘读取,并利用并行处理来加快数据加载速度。


  4. 性能测试表明,我们的自定义 dataLoader 通常优于标准 dataLoader,特别是在批次大小较小和有多个工作进程的情况下。


    1. 然而,随着批次大小的增大,性能增益会降低。


我们的自定义 dataLoader 通过减少冗余读取和最大化并行性来提高大型 3D 医疗数据集的数据加载效率。这一改进可以缩短训练时间并更好地利用硬件资源。


这篇博客是我和我的同事贾静南共同撰写的。