The Nuts and Bolts of Parallel-UNet: Implementation Details


by Backpropagation, October 6th, 2024

Too Long; Didn't Read

The implementation of TryOnDiffusion with Parallel-UNet includes a 256x256 architecture with key changes for improved performance. All three diffusion models are trained in JAX on TPU-v4 chips for 500K iterations, and inference of the whole pipeline is efficient, taking around 18 seconds for a batch of four.

Authors:

(1) Luyang Zhu, University of Washington and Google Research; work done while the author was an intern at Google;

(2) Dawei Yang, Google Research;

(3) Tyler Zhu, Google Research;

(4) Fitsum Reda, Google Research;

(5) William Chan, Google Research;

(6) Chitwan Saharia, Google Research;

(7) Mohammad Norouzi, Google Research;

(8) Ira Kemelmacher-Shlizerman, University of Washington and Google Research.

Abstract and 1. Introduction

2. Related Work

3. Method

3.1. Cascaded Diffusion Models for Try-On

3.2. Parallel-UNet

4. Experiments

5. Summary and Future Work and References


Appendix

A. Implementation Details

B. Additional Results

A. Implementation Details

A.1. Parallel-UNet

A.2. Training and Inference

TryOnDiffusion is implemented in JAX [4]. All three diffusion models are trained on 32 TPU-v4 chips for 500K iterations (around 3 days per diffusion model). After training, we run inference of the whole pipeline on 4 TPU-v4 chips with batch size 4, which takes around 18 seconds per batch.
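The paper does not say how the batch of four is distributed over the four TPU-v4 chips. A common JAX pattern, assumed here, is plain data parallelism: every chip holds a copy of the model and processes its own slice of the batch. The sketch below shows only that sharding plumbing with jax.pmap. The names toy_sampler, run_batch and NUM_STEPS are hypothetical stand-ins, not part of TryOnDiffusion, and the image shape is an arbitrary toy resolution.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 4  # hypothetical number of sampling steps, for illustration only


def toy_sampler(x):
    """Hypothetical placeholder for a diffusion sampling loop.

    A real pipeline would call the denoising networks here; this toy
    update just halves x on every step so the plumbing runs anywhere.
    """
    def body(i, x):
        return x * 0.5
    return jax.lax.fori_loop(0, NUM_STEPS, body, x)


# pmap compiles the function once and runs one copy per local device.
parallel_sampler = jax.pmap(toy_sampler)


def run_batch(batch):
    """Split a global batch evenly across local devices and sample in parallel."""
    n_dev = jax.local_device_count()
    b = batch.shape[0]
    if b % n_dev != 0:
        raise ValueError(f"batch size {b} is not divisible by {n_dev} devices")
    # Reshape to (devices, per_device_batch, H, W, C); pmap maps over axis 0.
    sharded = batch.reshape((n_dev, b // n_dev) + batch.shape[1:])
    out = parallel_sampler(sharded)
    # Merge the device axis back into the batch axis.
    return out.reshape(batch.shape)


batch = jnp.ones((4, 64, 64, 3), dtype=jnp.float32)  # toy batch of four images
images = run_batch(batch)
print(images.shape)  # (4, 64, 64, 3)
```

On four chips the batch is reshaped to (4, 1, 64, 64, 3), one image per chip; on a single CPU device it becomes (1, 4, 64, 64, 3), so the sketch also runs without TPUs. Newer JAX code often expresses the same idea with jax.jit and explicit array shardings instead of pmap. Taking the reported figures at face value, 18 seconds per batch of four is about 4.5 seconds per output image, and 500K iterations in roughly 3 days is about 0.52 seconds per training iteration (259,200 s / 500,000).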


This paper is available on arxiv under CC BY-NC-ND 4.0 DEED license.