PyTorch の DataLoader ( torch.utils.data.Dataloader
) は、ディープラーニング モデルのトレーニング用にデータを効率的にロードおよび前処理するための便利なツールです。デフォルトでは、PyTorch は単一ワーカー プロセス( num_workers=0
) を使用しますが、ユーザーは並列処理を活用してデータのロードを高速化するために、より大きな数値を指定できます。
ただし、これは汎用データローダーであり、並列化を提供しているにもかかわらず、特定のカスタムユースケースには適していません。この記事では、 torch.multiprocessing()
を使用して、3D 医療スキャンのデータセットから複数の 2D スライスをロードする速度を向上させる方法について説明します。
torch.utils.data.Dataset
について患者の 3D スキャンのセット (つまり、P1、P2、P3、…) と対応するスライスのリストが与えられたユースケースを想像します。私たちの目標は、反復ごとにスライスを出力するデータローダーを構築することです。以下のPython コードを確認してください。ここでは、 myDataset
と呼ばれる torch データセットを構築し、それをtorch.utils.data.Dataloader()
に渡しています。
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
私たちのユースケースの主な懸念は、 3D医療スキャンのサイズが大きいことです(ここではtime.sleep()
操作によってエミュレートされます)。そのため、
ディスクから読み取るには時間がかかる
3Dスキャンの大規模なデータセットは、ほとんどの場合、メモリに事前に読み込むことができない。
理想的には、各患者のスキャンは、それに関連付けられたすべてのスライスに対して 1 回だけ読み取る必要があります。ただし、データはtorch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
によってバッチ サイズに応じてワーカーに分割されるため、異なるワーカーが患者を 2 回読み取る可能性があります (以下の画像とログを確認してください)。
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
まとめると、 torch.utils.data.Dataloader
の既存の実装には次のような問題がある。
myDataset()
のコピーが渡されます(参照:
patientSliceList
を順番にループするため(下の画像を参照)、(patientId、sliceId)の組み合わせ間で自然なシャッフルは不可能です。(注:シャッフルは可能ですが、出力をメモリに保存する必要があります)
注: 各患者の 3D スキャンから、一連のスライスをまとめて返すこともできます。ただし、スライスに依存する 3D 配列 (たとえば、インタラクティブなリファインメント ネットワーク ( この作業の図 1 を参照)) も返したい場合は、データローダーのメモリ フットプリントが大幅に増加します。
torch.multiprocessing
の使用患者のスキャンが複数回読み取られるのを防ぐには、理想的には、各患者 ( 8 人の患者を想像してください) を特定の作業者が読み取る必要があります。
これを実現するために、torch データローダークラスと同じ内部ツール ( torch.multiprocessing()
など) を使用しますが、若干の違いがあります。カスタムデータローダーmyDataloader
のワークフロー図と コードを以下に示します。
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
上記のスニペット(代わりに8人の患者)には次の関数が含まれています。
__iter__()
- myDataloader()
はループなので、これが実際にループする関数になります。
_initWorkers()
- ここでは、個別の入力キューworkerInputQueues[workerId]
を使用してワーカー プロセスを作成します。これは、クラスが初期化されるときに呼び出されます。
fillInputQueues()
- この関数はループの開始時 (基本的には各エポックの開始時) に呼び出されます。個々のワーカーの入力キューを埋めます。
getSlice()
- これは、患者のボリュームからスライスを返すメインロジック関数です。 ここでコードを確認してください。
collate_tensor_fn()
- この関数は、torch リポジトリ ( torchv1.12.0 ) から直接コピーされ、データをまとめて処理するために使用されます。データローダーがデフォルト オプションと比較して高速化を実現するかどうかをテストするために、異なるワーカー数を使用して各データローダー ループの速度をテストします。実験では 2 つのパラメータを変更しました。
まず、おもちゃのデータセットで実験し、データローダーのパフォーマンスがはるかに高速であることを確認します。下の図を参照してください (または、 このコードで再現します)。
ここでは次のことがわかります
patientSlicesList
変数に患者 1 人あたり 5 つのスライスが含まれているためです。そのため、ワーカーは 2 番目の患者を読み取ってバッチの最後のインデックスに追加するまで待機する必要があります。次に、3Dスキャンを読み込み、スライスを抽出し、
私たちは、
また、ワーカー数を変えてデータ読み込み中のリソース使用率も監視しました。ワーカー数が増えると、CPU とメモリの使用量が増加することが確認されましたが、これは追加プロセスによって導入された並列処理によるものと予想されます。ユーザーは、最適なワーカー数を選択する際に、ハードウェアの制約とリソースの可用性を考慮する必要があります。
このブログ記事では、大規模な 3D 医療スキャンを含むデータセットを処理する際の PyTorch の標準 DataLoader の制限について検討し、 torch.multiprocessing
を使用してデータの読み込み効率を向上させるカスタム ソリューションを紹介しました。
これらの 3D 医療スキャンからのスライス抽出のコンテキストでは、ワーカーがメモリを共有しないため、デフォルトの dataLoader によって同じ患者のスキャンが複数回読み取られる可能性があります。この冗長性により、特に大規模なデータセットを処理する場合に大幅な遅延が発生します。
当社のカスタム dataLoader は、患者を作業者間で分割し、各 3D スキャンが作業者ごとに 1 回だけ読み取られるようにします。このアプローチにより、冗長なディスク読み取りが防止され、並列処理を利用してデータの読み込みが高速化されます。
パフォーマンス テストでは、特にバッチ サイズが小さく、ワーカー プロセスが複数ある場合、カスタム dataLoader のパフォーマンスが標準 dataLoader よりも優れていることが示されました。
当社のカスタム dataLoader は、冗長な読み取りを減らし、並列処理を最大化することで、大規模な 3D 医療データセットのデータ読み込み効率を高めます。この改善により、トレーニング時間が短縮され、ハードウェア リソースの利用率が向上します。
このブログは同僚のJingnan Jiaと共同で執筆しました。