PyTorch'un DataLoader'ı ( torch.utils.data.Dataloader
), derin öğrenme modellerinin eğitimi için verileri verimli bir şekilde yüklemek ve ön işlemek için zaten kullanışlı bir araçtır. Varsayılan olarak PyTorch, tek çalışanlı bir işlem kullanır ( num_workers=0
), ancak kullanıcılar paralellikten yararlanmak ve veri yüklemeyi hızlandırmak için daha yüksek bir sayı belirtebilir.
Ancak genel amaçlı bir veri yükleyici olduğundan ve paralelleştirme sunmasına rağmen yine de bazı özel kullanım durumları için uygun değildir. Bu yazıda torch.multiprocessing()
yöntemini kullanarak 3 boyutlu tıbbi taramalardan oluşan bir veri kümesinden birden fazla 2 boyutlu dilimin yüklenmesini nasıl hızlandırabileceğimizi araştırıyoruz.
torch.utils.data.Dataset
Hastalar için bir dizi 3 boyutlu taramanın (yani, P1, P2, P3,…) ve karşılık gelen dilimlerin bir listesinin verildiği bir kullanım durumu hayal ediyorum ; Amacımız her yinelemede bir dilim çıktısı veren bir veri yükleyici oluşturmaktır. myDataset
adında bir meşale veri kümesi oluşturduğumuz aşağıdaki Python kodunu kontrol edin ve bunu torch.utils.data.Dataloader()
dosyasına aktarın.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
Kullanım durumumuzla ilgili temel endişe , 3 boyutlu tıbbi taramaların boyutunun büyük olmasıdır ( burada time.sleep()
işlemiyle taklit edilmiştir) ve dolayısıyla
bunları diskten okumak zaman alıcı olabilir
ve çoğu durumda büyük bir 3D tarama veri kümesi önceden belleğe okunamaz
İdeal olarak, her hasta taramasını, onunla ilişkili tüm dilimler için yalnızca bir kez okumalıyız. Ancak veriler torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
tarafından toplu iş boyutuna bağlı olarak çalışanlara bölündüğünden, farklı çalışanların bir hastayı iki kez okuma olasılığı vardır ( resmi kontrol edin ve günlüğe kaydedin) altında ).
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
Özetlemek gerekirse, torch.utils.data.Dataloader
mevcut uygulamasıyla ilgili sorunlar şunlardır:
myDataset()
'in bir kopyası iletilir (Ref:
patientSliceList
( aşağıdaki resme bakın ) üzerinde sırayla döndüğü için, (hastaKimliği, dilimKimliği) kombinasyonları arasında doğal bir karıştırma mümkün değildir. ( Not: karıştırılabilir, ancak bu, çıktıların belleğe kaydedilmesini içerir )
Not: Ayrıca her hastanın 3D taramasından bir grup dilim bir araya getirilebilir. Ancak dilime bağlı 3B dizileri de döndürmek istersek (örneğin, etkileşimli iyileştirme ağları ( bu çalışmanın Şekil 1'ine bakın ), o zaman bu, veri yükleyicinizin bellek ayak izini büyük ölçüde artırır.
torch.multiprocessing
kullanmaHasta taramalarının birden fazla okunmasını önlemek için, ideal olarak her hastanın ( 8 hasta olduğunu varsayalım ) belirli bir çalışan tarafından okunmasına ihtiyacımız var.
Bunu başarmak için torch veri yükleyici sınıfıyla aynı dahili araçları kullanırız (örn. torch.multiprocessing()
), ancak küçük bir farkla. Özel veri yükleyicimiz myDataloader
için aşağıdaki iş akışı şeklini ve kodunu kontrol edin
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
Yukarıdaki kod parçası ( bunun yerine 8 hastayla birlikte ) aşağıdaki işlevleri içerir
__iter__()
- myDataloader()
bir döngü olduğundan, gerçekte üzerinde döngü yaptığı işlev budur.
_initWorkers()
- Burada, bireysel giriş kuyrukları ile işçi süreçlerimizi workerInputQueues[workerId]
oluştururuz. Bu, sınıf başlatıldığında çağrılır.
fillInputQueues()
- Bu işlev döngüye başladığımızda çağrılır ( esasen her çağın başlangıcında ). Bireysel çalışanın giriş kuyruğunu doldurur.
getSlice()
- Bu, hasta hacminden bir dilim döndüren ana mantık işlevidir. Buradaki kodu kontrol edin.
collate_tensor_fn()
- Bu işlev doğrudan torch deposundan - torchv1.12.0 kopyalanır ve verileri bir araya toplamak için kullanılır.Veri yükleyicimizin varsayılan seçeneğe kıyasla bir hızlanma sunup sunmadığını test etmek için, her bir veri yükleyici döngüsünün hızını farklı çalışan sayıları kullanarak test ederiz. Deneylerimizde iki parametreyi değiştirdik:
Önce oyuncak veri kümemizle denemeler yapıyoruz ve veri yükleyicimizin çok daha hızlı performans gösterdiğini görüyoruz. Aşağıdaki şekle bakın (veya bu kodla çoğaltın)
Burada aşağıdakileri görebiliriz
patientSlicesList
değişkeninin hasta başına 5 dilim içermesidir. Bu nedenle çalışanın, serinin son indeksine eklenecek ikinci hastayı okumayı beklemesi gerekiyor. Daha sonra 3 boyutlu taramaların yüklendiği, bir dilimin çıkarıldığı gerçek bir veri kümesini karşılaştırıyoruz.
Bunu gözlemledik
Ayrıca değişen çalışan sayılarıyla veri yükleme sırasında kaynak kullanımını da izledik. Çalışan sayısının artmasıyla, ek süreçlerin getirdiği paralellik nedeniyle beklenen CPU ve bellek kullanımının arttığını gözlemledik. Kullanıcılar, optimum çalışan sayısını seçerken donanım kısıtlamalarını ve kaynak kullanılabilirliğini dikkate almalıdır.
Bu blog yazısında, büyük 3D tıbbi taramalar içeren veri kümeleriyle uğraşırken PyTorch'un standart DataLoader'ının sınırlamalarını araştırdık ve veri yükleme verimliliğini artırmak için torch.multiprocessing
kullanarak özel bir çözüm sunduk.
Bu 3 boyutlu tıbbi taramalardan dilim çıkarma bağlamında, varsayılan dataLoader, çalışanların belleği paylaşmaması nedeniyle potansiyel olarak aynı hasta taramasının birden fazla okunmasına yol açabilir. Bu fazlalık, özellikle büyük veri kümeleriyle uğraşırken önemli gecikmelere neden olur.
Özel dataLoader'ımız hastaları çalışanlara bölerek her 3D taramanın çalışan başına yalnızca bir kez okunmasını sağlar. Bu yaklaşım, gereksiz disk okumalarını önler ve veri yüklemeyi hızlandırmak için paralel işlemeden yararlanır.
Performans testleri, özel dataLoader'ımızın, özellikle daha küçük parti boyutları ve birden fazla çalışan işlemiyle genel olarak standart dataLoader'dan daha iyi performans gösterdiğini gösterdi.
Özel dataLoader'ımız, gereksiz okumaları azaltarak ve paralelliği maksimuma çıkararak büyük 3D tıbbi veri kümeleri için veri yükleme verimliliğini artırır. Bu iyileştirme, daha hızlı eğitim sürelerine ve donanım kaynaklarının daha iyi kullanılmasına yol açabilir.
Bu blog meslektaşım Jingnan Jia ile birlikte yazıldı.