paint-brush
Torch.multiprocessing を使用して Torch データローダーの並列化を改善する方法by@pixelperfectionist
376
376

Torch.multiprocessing を使用して Torch データローダーの並列化を改善する方法

Prerak Mody13m2024/06/10
Read on Terminal Reader

PyTorch データローダーは、ディープラーニング モデルのトレーニング用にデータを効率的に読み込み、前処理するためのツールです。この記事では、カスタム データローダーと torch.multiprocessing を使用してこのプロセスを高速化する方法を探ります。3D 医療スキャンのデータセットから複数の 2D スライスを読み込む実験を行います。
featured image - Torch.multiprocessing を使用して Torch データローダーの並列化を改善する方法
Prerak Mody HackerNoon profile picture
0-item

導入

PyTorch の DataLoader ( torch.utils.data.Dataloader ) は、ディープラーニング モデルのトレーニング用にデータを効率的にロードおよび前処理するための便利なツールです。デフォルトでは、PyTorch は単一ワーカー プロセス( num_workers=0 ) を使用しますが、ユーザーは並列処理を活用してデータのロードを高速化するために、より大きな数値を指定できます。


ただし、これは汎用データローダーであり、並列化を提供しているにもかかわらず、特定のカスタムユースケースには適していません。この記事では、 torch.multiprocessing()を使用して、3D 医療スキャンのデータセットから複数の 2D スライスをロードする速度を向上させる方法について説明します。


各患者の 3D スキャンからスライスのセットを抽出したいと考えています。これらの患者は大規模なデータセットの一部です。



torch.utils.data.Datasetについて

患者の 3D スキャンのセット (つまり、P1、P2、P3、…) と対応するスライスのリストが与えられたユースケースを想像します。私たちの目標は、反復ごとにスライスを出力するデータローダーを構築することです。以下のPython コードを確認してください。ここでは、 myDatasetと呼ばれる torch データセットを構築し、それをtorch.utils.data.Dataloader()に渡しています。


 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])


私たちのユースケースの主な懸念は、 3D医療スキャンのサイズが大きいことです(ここではtime.sleep()操作によってエミュレートされます)。そのため、

  • ディスクから読み取るには時間がかかる

  • 3Dスキャンの大規模なデータセットは、ほとんどの場合、メモリに事前に読み込むことができない。


理想的には、各患者のスキャンは、それに関連付けられたすべてのスライスに対して 1 回だけ読み取る必要があります。ただし、データはtorch.utils.data.dataloader(myDataset, batch_size=b, workers=n)によってバッチ サイズに応じてワーカーに分割されるため、異なるワーカーが患者を 2 回読み取る可能性があります (以下の画像とログを確認してください)。

Torch は、バッチ サイズ (この場合は 3) に応じて、データセットの読み込みを各ワーカーに分割します。これにより、各患者は複数のワーカーによって読み取られます。


 - [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]


まとめると、 torch.utils.data.Dataloaderの既存の実装には次のような問題がある。

  • 各ワーカーにはmyDataset()のコピーが渡されます(参照:トーチ v1.2.0 )、共有メモリがないため、患者の 3D スキャンが 2 重のディスクから読み取られることになります。


  • さらに、トーチはpatientSliceListを順番にループするため(下の画像を参照)、(patientId、sliceId)の組み合わせ間で自然なシャッフルは不可能です。(注:シャッフルは可能ですが、出力をメモリに保存する必要があります


標準の torch.utils.data.Dataloader() には、ワーカーから出力が抽出される方法をグローバルに管理する内部キューがあります。特定のワーカーでデータが準備されていても、このグローバル キューを尊重する必要があるため、データを出力することはできません。



注: 各患者の 3D スキャンから、一連のスライスをまとめて返すこともできます。ただし、スライスに依存する 3D 配列 (たとえば、インタラクティブなリファインメント ネットワーク ( この作業の図 1 を参照)) も返したい場合は、データローダーのメモリ フットプリントが大幅に増加します。



torch.multiprocessingの使用

患者のスキャンが複数回読み取られるのを防ぐには、理想的には、各患者 ( 8 人の患者を想像してください) を特定の作業者が読み取る必要があります。

ここでは、各作業者は患者(のセット)の読み取りに集中します。


これを実現するために、torch データローダークラスと同じ内部ツール ( torch.multiprocessing()など) を使用しますが、若干の違いがあります。カスタムデータローダーmyDataloaderのワークフロー図と コードを以下に示します。

ここで、出力キュー (下部) には各ワーカーからの出力が含まれています。各ワーカーは、特定の患者セットのみの入力情報 (上部に表示されている入力キュー) を受け取ります。これにより、患者の 3D スキャンが複数回読み取られるのを防ぎます。



 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()


上記のスニペット(代わりに8人の患者)には次の関数が含まれています。

  • __iter__() - myDataloader()はループなので、これが実際にループする関数になります。


  • _initWorkers() - ここでは、個別の入力キューworkerInputQueues[workerId]を使用してワーカー プロセスを作成します。これは、クラスが初期化されるときに呼び出されます。


  • fillInputQueues() - この関数はループの開始時 (基本的には各エポックの開始時) に呼び出されます。個々のワーカーの入力キューを埋めます。


  • getSlice() - これは、患者のボリュームからスライスを返すメインロジック関数です。 ここでコードを確認してください。


  • collate_tensor_fn() - この関数は、torch リポジトリ ( torchv1.12.0 ) から直接コピーされ、データをまとめて処理するために使用されます。


パフォーマンス

データローダーがデフォルト オプションと比較して高速化を実現するかどうかをテストするために、異なるワーカー数を使用して各データローダー ループの速度をテストします。実験では 2 つのパラメータを変更しました。


  • 従業員数: 1、2、4、8 のワーカー プロセスをテストしました。
  • バッチサイズ: 1 から 8 までのさまざまなバッチ サイズを評価しました。

おもちゃのデータセット

まず、おもちゃのデータセットで実験し、データローダーのパフォーマンスがはるかに高速であることを確認します。下の図を参照してください (または、 このコードで再現します)。
合計時間が短く、反復回数/秒数が多いほど、データローダーの性能は向上します。

ここでは次のことがわかります

  • 単一のワーカーを使用する場合、両方のデータローダーは同じになります。


  • 追加のワーカー (つまり 2、4、8) を使用すると、両方のデータローダーで速度が向上しますが、カスタム データローダーでは速度の向上がはるかに大きくなります。


  • バッチ サイズを 6 にすると (1、2、3、4 と比較)、パフォーマンスに若干の影響が出ます。これは、このおもちゃのデータセットでは、 patientSlicesList変数に患者 1 人あたり 5 つのスライスが含まれているためです。そのため、ワーカーは 2 番目の患者を読み取ってバッチの最後のインデックスに追加するまで待機する必要があります。

実世界データセット

次に、3Dスキャンを読み込み、スライスを抽出し、追加の前処理が行われますすると、スライスとその他の配列が返されます。結果については下の図を参照してください。


私たちは、ワーカー数(およびバッチサイズ)を増やすと、一般的にデータの読み込みが速くなります。そのため、トレーニングが高速化される可能性があります。バッチ サイズが小さい場合 (1 または 2 など)、ワーカー数を 2 倍にすると、速度が大幅に向上します。ただし、バッチ サイズが大きくなるにつれて、ワーカーを追加することで得られる限界的な改善は減少します。

反復回数/秒が高いほど、データローダーは高速になります。

リソースの活用

また、ワーカー数を変えてデータ読み込み中のリソース使用率も監視しました。ワーカー数が増えると、CPU とメモリの使用量が増加することが確認されましたが、これは追加プロセスによって導入された並列処理によるものと予想されます。ユーザーは、最適なワーカー数を選択する際に、ハードウェアの制約とリソースの可用性を考慮する必要があります。

まとめ

  1. このブログ記事では、大規模な 3D 医療スキャンを含むデータセットを処理する際の PyTorch の標準 DataLoader の制限について検討し、 torch.multiprocessingを使用してデータの読み込み効率を向上させるカスタム ソリューションを紹介しました。


  2. これらの 3D 医療スキャンからのスライス抽出のコンテキストでは、ワーカーがメモリを共有しないため、デフォルトの dataLoader によって同じ患者のスキャンが複数回読み取られる可能性があります。この冗長性により、特に大規模なデータセットを処理する場合に大幅な遅延が発生します。


  3. 当社のカスタム dataLoader は、患者を作業者間で分割し、各 3D スキャンが作業者ごとに 1 回だけ読み取られるようにします。このアプローチにより、冗長なディスク読み取りが防止され、並列処理を利用してデータの読み込みが高速化されます。


  4. パフォーマンス テストでは、特にバッチ サイズが小さく、ワーカー プロセスが複数ある場合、カスタム dataLoader のパフォーマンスが標準 dataLoader よりも優れていることが示されました。


    1. ただし、バッチ サイズが大きくなるにつれてパフォーマンスの向上は減少します。


当社のカスタム dataLoader は、冗長な読み取りを減らし、並列処理を最大化することで、大規模な 3D 医療データセットのデータ読み込み効率を高めます。この改善により、トレーニング時間が短縮され、ハードウェア リソースの利用率が向上します。


このブログは同僚のJingnan Jiaと共同で執筆しました。