O DataLoader do PyTorch ( torch.utils.data.Dataloader
) já é uma ferramenta útil para carregar e pré-processar dados com eficiência para treinar modelos de aprendizado profundo. Por padrão, PyTorch usa um processo de trabalho único ( num_workers=0
), mas os usuários podem especificar um número maior para aproveitar o paralelismo e acelerar o carregamento de dados.
No entanto, por ser um dataloader de uso geral e, embora ofereça paralelização, ainda não é adequado para determinados casos de uso personalizados. Nesta postagem, exploramos como podemos acelerar o carregamento de vários cortes 2D de um conjunto de dados de exames médicos 3D usando torch.multiprocessing()
.
torch.utils.data.Dataset
Imagino um caso de uso em que seja fornecido um conjunto de varreduras 3D para pacientes (ou seja, P1, P2, P3,…) e uma lista de cortes correspondentes; nosso objetivo é construir um dataloader que produza uma fatia em cada iteração . Verifique o código Python abaixo, onde construímos um conjunto de dados de tocha chamado myDataset
e passamos para torch.utils.data.Dataloader()
.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
A principal preocupação com nosso caso de uso é que as varreduras médicas 3D são grandes ( emuladas aqui pela operação time.sleep()
) e, portanto,
lê-los do disco pode consumir muito tempo
e um grande conjunto de dados de digitalizações 3D, na maioria dos casos, não pode ser pré-lido na memória
Idealmente, deveríamos ler cada exame do paciente apenas uma vez para todos os cortes associados a ele. Mas como os dados são divididos por torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
em trabalhadores dependendo do tamanho do lote, existe a possibilidade de diferentes trabalhadores lerem um paciente duas vezes ( verifique a imagem e registre abaixo ).
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
Para resumir, aqui estão os problemas com a implementação existente de torch.utils.data.Dataloader
myDataset()
(Ref:
patientSliceList
( veja a imagem abaixo ), nenhum embaralhamento natural é possível entre os combos (pacienteId, sliceId). ( Nota: pode-se embaralhar, mas isso envolve armazenar as saídas na memória )
Nota: Também é possível retornar um monte de fatias da digitalização 3D de cada paciente. Mas se desejarmos também retornar matrizes 3D dependentes de fatia (por exemplo, redes de refinamento interativo ( veja a Figura 1 deste trabalho ), isso aumentará muito o consumo de memória do seu carregador de dados.
torch.multiprocessing
Para evitar leituras múltiplas de exames de pacientes , idealmente precisaríamos que cada paciente ( vamos imaginar 8 pacientes ) fosse lido por um funcionário específico.
Para conseguir isso, usamos as mesmas ferramentas internas da classe do dataloader torch (ou seja, torch.multiprocessing()
), mas com uma pequena diferença. Verifique a figura e o código do fluxo de trabalho abaixo para nosso dataloader personalizado - myDataloader
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
O trecho acima ( com 8 pacientes ) contém as seguintes funções
__iter__()
- Como myDataloader()
é um loop, esta é a função sobre a qual ele realmente faz o loop.
_initWorkers()
- Aqui, criamos nossos processos de trabalho com suas filas de entrada individuais workerInputQueues[workerId]
. Isso é chamado quando a classe é inicializada.
fillInputQueues()
- Esta função é chamada quando iniciamos o loop ( essencialmente no início de cada época ). Ele preenche a fila de entrada do trabalhador individual.
getSlice()
- Esta é a função lógica principal que retorna uma fatia do volume do paciente. Verifique o código aqui .
collate_tensor_fn()
- Esta função é copiada diretamente do repositório torch - torchv1.12.0 e é usada para agrupar dados em lote.Para testar se nosso dataloader oferece aceleração em comparação com a opção padrão, testamos a velocidade de cada loop do dataloader usando diferentes contagens de trabalhadores . Variamos dois parâmetros em nossos experimentos:
Primeiro experimentamos nosso conjunto de dados de brinquedo e vemos que nosso carregador de dados funciona muito mais rápido. Veja a figura abaixo (ou reproduza com este código )
Aqui podemos ver o seguinte
patientSlicesList
contém 5 fatias por paciente. Assim, o trabalhador precisa aguardar a leitura do segundo paciente para adicionar ao último índice do lote. Em seguida, comparamos um conjunto de dados real onde as varreduras 3D são carregadas, uma fatia é extraída,
Observamos que
Também monitoramos a utilização de recursos durante o carregamento de dados com contagens variadas de trabalhadores. Com um maior número de trabalhadores, observamos um aumento no uso de CPU e memória, o que é esperado devido ao paralelismo introduzido por processos adicionais. Os usuários devem considerar as restrições de hardware e a disponibilidade de recursos ao escolher a contagem ideal de trabalhadores.
Nesta postagem do blog, exploramos as limitações do DataLoader padrão do PyTorch ao lidar com conjuntos de dados contendo grandes exames médicos 3D e apresentamos uma solução personalizada usando torch.multiprocessing
para melhorar a eficiência do carregamento de dados.
No contexto da extração de fatias desses exames médicos 3D, o dataLoader padrão pode potencialmente levar a múltiplas leituras do mesmo exame do paciente, pois os trabalhadores não compartilham memória. Esta redundância causa atrasos significativos, especialmente quando se trata de grandes conjuntos de dados.
Nosso dataLoader personalizado divide os pacientes entre os trabalhadores, garantindo que cada digitalização 3D seja lida apenas uma vez por trabalhador. Essa abordagem evita leituras redundantes de disco e aproveita o processamento paralelo para acelerar o carregamento de dados.
Os testes de desempenho mostraram que nosso dataLoader personalizado geralmente supera o dataLoader padrão, especialmente com lotes menores e vários processos de trabalho.
Nosso dataLoader personalizado melhora a eficiência do carregamento de dados para grandes conjuntos de dados médicos 3D, reduzindo leituras redundantes e maximizando o paralelismo. Essa melhoria pode levar a tempos de treinamento mais rápidos e melhor utilização dos recursos de hardware.
Este blog foi escrito em conjunto com meu colega Jingnan Jia .