paint-brush
Como melhorar a paralelização de dataloaders Torch usando Torch.multiprocessingpor@pixelperfectionist
548 leituras
548 leituras

Como melhorar a paralelização de dataloaders Torch usando Torch.multiprocessing

por Prerak Mody13m2024/06/10
Read on Terminal Reader

Muito longo; Para ler

O dataloader PyTorch é uma ferramenta para carregar e pré-processar dados com eficiência para treinar modelos de aprendizado profundo. Nesta postagem, exploramos como podemos acelerar esse processo usando nosso dataloader personalizado junto com torch.multiprocessing. Experimentamos carregar vários cortes 2D de um conjunto de dados de exames médicos 3D.
featured image - Como melhorar a paralelização de dataloaders Torch usando Torch.multiprocessing
Prerak Mody HackerNoon profile picture
0-item

Introdução

O DataLoader do PyTorch ( torch.utils.data.Dataloader ) já é uma ferramenta útil para carregar e pré-processar dados com eficiência para treinar modelos de aprendizado profundo. Por padrão, PyTorch usa um processo de trabalho único ( num_workers=0 ), mas os usuários podem especificar um número maior para aproveitar o paralelismo e acelerar o carregamento de dados.


No entanto, por ser um dataloader de uso geral e, embora ofereça paralelização, ainda não é adequado para determinados casos de uso personalizados. Nesta postagem, exploramos como podemos acelerar o carregamento de vários cortes 2D de um conjunto de dados de exames médicos 3D usando torch.multiprocessing() .


Desejamos extrair um conjunto de fatias da digitalização 3D de cada paciente. Esses pacientes fazem parte de um grande conjunto de dados.



Nosso torch.utils.data.Dataset

Imagino um caso de uso em que seja fornecido um conjunto de varreduras 3D para pacientes (ou seja, P1, P2, P3,…) e uma lista de cortes correspondentes; nosso objetivo é construir um dataloader que produza uma fatia em cada iteração . Verifique o código Python abaixo, onde construímos um conjunto de dados de tocha chamado myDataset e passamos para torch.utils.data.Dataloader() .


 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])


A principal preocupação com nosso caso de uso é que as varreduras médicas 3D são grandes ( emuladas aqui pela operação time.sleep() ) e, portanto,

  • lê-los do disco pode consumir muito tempo

  • e um grande conjunto de dados de digitalizações 3D, na maioria dos casos, não pode ser pré-lido na memória


Idealmente, deveríamos ler cada exame do paciente apenas uma vez para todos os cortes associados a ele. Mas como os dados são divididos por torch.utils.data.dataloader(myDataset, batch_size=b, workers=n) em trabalhadores dependendo do tamanho do lote, existe a possibilidade de diferentes trabalhadores lerem um paciente duas vezes ( verifique a imagem e registre abaixo ).

O Torch divide o carregamento do conjunto de dados em cada trabalhador dependendo do tamanho do lote (=3, neste caso). Devido a isso, cada paciente é lido por vários trabalhadores.


 - [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]


Para resumir, aqui estão os problemas com a implementação existente de torch.utils.data.Dataloader

  • Cada um dos trabalhadores recebe uma cópia do myDataset() (Ref: tocha v1.2. 0 ), e como não possuem memória compartilhada, isso leva a uma leitura dupla do disco da digitalização 3D de um paciente.


  • Além disso, como a tocha percorre sequencialmente patientSliceList ( veja a imagem abaixo ), nenhum embaralhamento natural é possível entre os combos (pacienteId, sliceId). ( Nota: pode-se embaralhar, mas isso envolve armazenar as saídas na memória )


O torch.utils.data.Dataloader() padrão possui uma fila interna que gerencia globalmente como as saídas são extraídas dos trabalhadores. Mesmo que os dados estejam prontos por um determinado trabalhador, ele não poderá produzi-los, pois deve respeitar essa fila global.



Nota: Também é possível retornar um monte de fatias da digitalização 3D de cada paciente. Mas se desejarmos também retornar matrizes 3D dependentes de fatia (por exemplo, redes de refinamento interativo ( veja a Figura 1 deste trabalho ), isso aumentará muito o consumo de memória do seu carregador de dados.



Usando torch.multiprocessing

Para evitar leituras múltiplas de exames de pacientes , idealmente precisaríamos que cada paciente ( vamos imaginar 8 pacientes ) fosse lido por um funcionário específico.

Aqui, cada trabalhador está focado na leitura de um (conjunto de) paciente(s).


Para conseguir isso, usamos as mesmas ferramentas internas da classe do dataloader torch (ou seja, torch.multiprocessing() ), mas com uma pequena diferença. Verifique a figura e o código do fluxo de trabalho abaixo para nosso dataloader personalizado - myDataloader

Aqui, a fila de saída (parte inferior) contém as saídas de cada trabalhador. Cada trabalhador recebe informações de entrada (fila de entrada mostrada na parte superior) apenas para um conjunto específico de pacientes. Assim, isso evita múltiplas leituras da digitalização 3D de um paciente.



 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()


O trecho acima ( com 8 pacientes ) contém as seguintes funções

  • __iter__() - Como myDataloader() é um loop, esta é a função sobre a qual ele realmente faz o loop.


  • _initWorkers() - Aqui, criamos nossos processos de trabalho com suas filas de entrada individuais workerInputQueues[workerId] . Isso é chamado quando a classe é inicializada.


  • fillInputQueues() - Esta função é chamada quando iniciamos o loop ( essencialmente no início de cada época ). Ele preenche a fila de entrada do trabalhador individual.


  • getSlice() - Esta é a função lógica principal que retorna uma fatia do volume do paciente. Verifique o código aqui .


  • collate_tensor_fn() - Esta função é copiada diretamente do repositório torch - torchv1.12.0 e é usada para agrupar dados em lote.


Desempenho

Para testar se nosso dataloader oferece aceleração em comparação com a opção padrão, testamos a velocidade de cada loop do dataloader usando diferentes contagens de trabalhadores . Variamos dois parâmetros em nossos experimentos:


  • Número de trabalhadores : testamos 1, 2, 4 e 8 processos de trabalho.
  • Tamanho do batch : Avaliamos diferentes tamanhos de lote variando de 1 a 8.

Conjunto de dados de brinquedos

Primeiro experimentamos nosso conjunto de dados de brinquedo e vemos que nosso carregador de dados funciona muito mais rápido. Veja a figura abaixo (ou reproduza com este código )
Menor tempo total e maiores iterações/s significam um melhor carregador de dados.

Aqui podemos ver o seguinte

  • Ao usar um único trabalhador, ambos os dataloaders são iguais.


  • Ao usar trabalhadores adicionais (ou seja, 2,4,8), há uma aceleração em ambos os dataloaders, no entanto, a aceleração é muito maior em nosso dataloader personalizado.


  • Ao usar um tamanho de lote de 6 (em comparação com 1,2,3,4), há um pequeno impacto no desempenho. Isso ocorre porque, em nosso conjunto de dados de brinquedo, a variável patientSlicesList contém 5 fatias por paciente. Assim, o trabalhador precisa aguardar a leitura do segundo paciente para adicionar ao último índice do lote.

Conjunto de dados do mundo real

Em seguida, comparamos um conjunto de dados real onde as varreduras 3D são carregadas, uma fatia é extraída, algum pré-processamento adicional é feito e, em seguida, a fatia e outras matrizes serão retornadas. Veja a figura abaixo para resultados.


Observamos que aumentar o número de processos de trabalho (e tamanhos de lote) geralmente levou a um carregamento de dados mais rápido e, portanto, pode levar a um treinamento mais rápido. Para lotes menores (por exemplo, 1 ou 2), duplicar o número de trabalhadores resultou em acelerações muito maiores. No entanto, à medida que o tamanho do lote aumentou, a melhoria marginal resultante da adição de mais trabalhadores diminuiu.

Quanto maiores as iterações/s, mais rápido será o carregador de dados.

Utilização de recursos

Também monitoramos a utilização de recursos durante o carregamento de dados com contagens variadas de trabalhadores. Com um maior número de trabalhadores, observamos um aumento no uso de CPU e memória, o que é esperado devido ao paralelismo introduzido por processos adicionais. Os usuários devem considerar as restrições de hardware e a disponibilidade de recursos ao escolher a contagem ideal de trabalhadores.

Resumo

  1. Nesta postagem do blog, exploramos as limitações do DataLoader padrão do PyTorch ao lidar com conjuntos de dados contendo grandes exames médicos 3D e apresentamos uma solução personalizada usando torch.multiprocessing para melhorar a eficiência do carregamento de dados.


  2. No contexto da extração de fatias desses exames médicos 3D, o dataLoader padrão pode potencialmente levar a múltiplas leituras do mesmo exame do paciente, pois os trabalhadores não compartilham memória. Esta redundância causa atrasos significativos, especialmente quando se trata de grandes conjuntos de dados.


  3. Nosso dataLoader personalizado divide os pacientes entre os trabalhadores, garantindo que cada digitalização 3D seja lida apenas uma vez por trabalhador. Essa abordagem evita leituras redundantes de disco e aproveita o processamento paralelo para acelerar o carregamento de dados.


  4. Os testes de desempenho mostraram que nosso dataLoader personalizado geralmente supera o dataLoader padrão, especialmente com lotes menores e vários processos de trabalho.


    1. No entanto, os ganhos de desempenho diminuíram com lotes maiores.


Nosso dataLoader personalizado melhora a eficiência do carregamento de dados para grandes conjuntos de dados médicos 3D, reduzindo leituras redundantes e maximizando o paralelismo. Essa melhoria pode levar a tempos de treinamento mais rápidos e melhor utilização dos recursos de hardware.


Este blog foi escrito em conjunto com meu colega Jingnan Jia .