PyTorch का DataLoader ( torch.utils.data.Dataloader
) डीप लर्निंग मॉडल को प्रशिक्षित करने के लिए डेटा को कुशलतापूर्वक लोड करने और प्रीप्रोसेस करने के लिए पहले से ही एक उपयोगी उपकरण है। डिफ़ॉल्ट रूप से, PyTorch एकल-कार्यकर्ता प्रक्रिया ( num_workers=0
) का उपयोग करता है, लेकिन उपयोगकर्ता समानांतरता का लाभ उठाने और डेटा लोडिंग को गति देने के लिए एक उच्च संख्या निर्दिष्ट कर सकते हैं।
हालाँकि, चूँकि यह एक सामान्य-उद्देश्य वाला डेटा लोडर है, और भले ही यह समानांतरीकरण प्रदान करता है, फिर भी यह कुछ कस्टम उपयोग मामलों के लिए उपयुक्त नहीं है। इस पोस्ट में, हम यह पता लगाते हैं कि हम torch.multiprocessing()
का उपयोग करके 3D मेडिकल स्कैन के डेटासेट से कई 2D स्लाइस को लोड करने की गति कैसे बढ़ा सकते हैं।
torch.utils.data.Dataset
मैं एक ऐसे उपयोग के मामले की कल्पना करता हूँ जिसमें रोगियों के लिए 3D स्कैन का एक सेट दिया गया है (यानी, P1, P2, P3, ...) और संबंधित स्लाइस की एक सूची; हमारा लक्ष्य एक डेटालोडर बनाना है जो हर पुनरावृत्ति में एक स्लाइस आउटपुट करता है । नीचे दिए गए पायथन कोड को देखें जहाँ हम myDataset
नामक एक टॉर्च डेटासेट बनाते हैं, और इसे torch.utils.data.Dataloader()
में पास करते हैं।
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
हमारे उपयोग के मामले में मुख्य चिंता यह है कि 3D मेडिकल स्कैन आकार में बड़े हैं ( यहाँ time.sleep()
ऑपरेशन द्वारा अनुकरण किया गया है) और इसलिए
डिस्क से उन्हें पढ़ने में समय लग सकता है
और अधिकांश मामलों में 3D स्कैन का एक बड़ा डेटासेट मेमोरी में पहले से पढ़ा नहीं जा सकता है
आदर्श रूप से, हमें प्रत्येक रोगी स्कैन को उससे जुड़े सभी स्लाइस के लिए केवल एक बार पढ़ना चाहिए। लेकिन चूंकि डेटा बैच आकार के आधार पर torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
द्वारा कार्यकर्ताओं में विभाजित किया जाता है, इसलिए अलग-अलग कार्यकर्ताओं द्वारा एक रोगी को दो बार पढ़ने की संभावना है ( नीचे दी गई छवि और लॉग देखें )।
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
संक्षेप में, torch.utils.data.Dataloader
के मौजूदा कार्यान्वयन के साथ समस्याएं यहां दी गई हैं
myDataset()
की एक प्रति दी जाती है (संदर्भ:
patientSliceList
( नीचे छवि देखें ) पर लूप करता है, इसलिए (रोगी आईडी, स्लाइस आईडी) कॉम्बो के बीच कोई प्राकृतिक फेरबदल संभव नहीं है। ( नोट: कोई फेरबदल कर सकता है, लेकिन इसमें मेमोरी में आउटपुट संग्रहीत करना शामिल है )
नोट: कोई भी व्यक्ति प्रत्येक मरीज के 3D स्कैन से एक साथ कई स्लाइस लौटा सकता है। लेकिन अगर हम स्लाइस-निर्भर 3D एरे (उदाहरण के लिए, इंटरैक्टिव रिफाइनमेंट नेटवर्क ( इस कार्य का चित्र 1 देखें ) भी लौटाना चाहते हैं, तो इससे आपके डेटा लोडर की मेमोरी फ़ुटप्रिंट बहुत बढ़ जाती है।
torch.multiprocessing
उपयोग करनारोगी स्कैन को एक से अधिक बार पढ़ने से रोकने के लिए, हमें आदर्श रूप से प्रत्येक रोगी ( मान लीजिए 8 रोगी हैं ) को एक विशेष कार्यकर्ता द्वारा पढ़ने की आवश्यकता होगी।
इसे प्राप्त करने के लिए, हम टॉर्च डेटालोडर क्लास (यानी, torch.multiprocessing()
) के समान आंतरिक उपकरणों का उपयोग करते हैं, लेकिन थोड़े अंतर के साथ। हमारे कस्टम डेटालोडर - myDataloader
के लिए नीचे वर्कफ़्लो चित्र और कोड देखें
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
उपरोक्त स्निपेट ( जिसमें 8 मरीज हैं ) में निम्नलिखित कार्य शामिल हैं
__iter__()
- चूंकि myDataloader()
एक लूप है, यह वह फ़ंक्शन है जिस पर यह वास्तव में लूप करता है।
_initWorkers()
- यहाँ, हम अपने कार्यकर्ता प्रक्रियाओं को उनके व्यक्तिगत इनपुट कतारों workerInputQueues[workerId]
के साथ बनाते हैं। जब क्लास को आरंभीकृत किया जाता है तो इसे कॉल किया जाता है।
fillInputQueues()
- यह फ़ंक्शन तब कॉल किया जाता है जब हम लूप शुरू करते हैं ( अनिवार्य रूप से प्रत्येक युग की शुरुआत में )। यह व्यक्तिगत कार्यकर्ता की इनपुट कतार को भरता है।
getSlice()
- यह मुख्य लॉजिक फ़ंक्शन है जो रोगी वॉल्यूम से स्लाइस लौटाता है। कोड यहाँ देखें।
collate_tensor_fn()
- यह फ़ंक्शन सीधे torch repo - torchv1.12.0 से कॉपी किया गया है और डेटा को एक साथ बैच करने के लिए उपयोग किया जाता है।यह जांचने के लिए कि क्या हमारा डेटा लोडर डिफ़ॉल्ट विकल्प की तुलना में गति प्रदान करता है, हम अलग-अलग वर्कर काउंट का उपयोग करके प्रत्येक डेटा लोडर लूप की गति का परीक्षण करते हैं। हमने अपने प्रयोगों में दो मापदंडों को बदला:
हम पहले अपने टॉय डेटासेट के साथ प्रयोग करते हैं और देखते हैं कि हमारा डेटालोडर बहुत तेज़ काम करता है। नीचे दिया गया चित्र देखें (या इस कोड के साथ पुन: प्रस्तुत करें)
यहाँ हम निम्नलिखित देख सकते हैं
patientSlicesList
चर में प्रति रोगी 5 स्लाइस होते हैं। इसलिए, कार्यकर्ता को बैच के अंतिम इंडेक्स में जोड़ने के लिए दूसरे रोगी को पढ़ने के लिए प्रतीक्षा करने की आवश्यकता होती है। फिर हम एक वास्तविक डेटासेट का बेंचमार्क बनाते हैं, जहां 3D स्कैन लोड किए जाते हैं, एक स्लाइस निकाली जाती है,
हमने देखा कि
हमने अलग-अलग वर्कर काउंट के साथ डेटा लोडिंग के दौरान संसाधन उपयोग की निगरानी भी की। वर्कर की अधिक संख्या के साथ, हमने CPU और मेमोरी उपयोग में वृद्धि देखी, जो अतिरिक्त प्रक्रियाओं द्वारा शुरू की गई समानांतरता के कारण अपेक्षित है। उपयोगकर्ताओं को इष्टतम वर्कर काउंट चुनते समय अपनी हार्डवेयर बाधाओं और संसाधन उपलब्धता पर विचार करना चाहिए।
इस ब्लॉग पोस्ट में, हमने बड़े 3D मेडिकल स्कैन वाले डेटासेट से निपटने के दौरान PyTorch के मानक DataLoader की सीमाओं का पता लगाया और डेटा लोडिंग दक्षता में सुधार करने के लिए torch.multiprocessing
का उपयोग करके एक कस्टम समाधान प्रस्तुत किया।
इन 3D मेडिकल स्कैन से स्लाइस निष्कर्षण के संदर्भ में, डिफ़ॉल्ट डेटा लोडर संभावित रूप से एक ही रोगी स्कैन के कई रीड्स की ओर ले जा सकता है क्योंकि वर्कर मेमोरी साझा नहीं करते हैं। यह अतिरेक महत्वपूर्ण देरी का कारण बनता है, खासकर जब बड़े डेटासेट से निपटना होता है।
हमारा कस्टम डेटा लोडर मरीजों को वर्कर्स के बीच विभाजित करता है, यह सुनिश्चित करता है कि प्रत्येक 3D स्कैन को प्रत्येक वर्कर के लिए केवल एक बार पढ़ा जाए। यह दृष्टिकोण अनावश्यक डिस्क रीड को रोकता है और डेटा लोडिंग को गति देने के लिए समानांतर प्रसंस्करण का लाभ उठाता है।
प्रदर्शन परीक्षण से पता चला कि हमारा कस्टम डेटा लोडर आम तौर पर मानक डेटा लोडर से बेहतर प्रदर्शन करता है, विशेष रूप से छोटे बैच आकार और एकाधिक कार्यकर्ता प्रक्रियाओं के साथ।
हमारा कस्टम डेटा लोडर अनावश्यक रीड्स को कम करके और समानांतरता को अधिकतम करके बड़े 3D मेडिकल डेटासेट के लिए डेटा लोडिंग दक्षता को बढ़ाता है। इस सुधार से प्रशिक्षण समय में तेज़ी आ सकती है और हार्डवेयर संसाधनों का बेहतर उपयोग हो सकता है।
यह ब्लॉग मैंने अपनी सहकर्मी जिंगनान जिया के साथ मिलकर लिखा है।