
Using Autodiff to Estimate Posterior Moments, Marginals and Samples: Related Work

by Bayesian Inference, April 15th, 2024

Too Long; Didn't Read

Importance weighting allows us to reweight samples drawn from a proposal in order to compute expectations of a different distribution.

This paper is available on arXiv under CC 4.0 license.


(1) Sam Bowyer, Equal contribution, Department of Mathematics and [email protected];

(2) Thomas Heap, Equal contribution, Department of Computer Science University of Bristol and [email protected];

(3) Laurence Aitchison, Department of Computer Science University of Bristol and [email protected].

There is a considerable body of work in the discrete graphical model setting that computes posterior expectations, marginals and samples (Dawid 1992; Pfeffer 2005; Bidyuk and Dechter 2007; Geldenhuys, Dwyer, and Visser 2012; Gogate and Dechter 2012; Claret et al. 2013; Sankaranarayanan, Chakarov, and Gulwani 2013; Goodman and Stuhlmüller 2014; Gehr, Misailovic, and Vechev 2016; Narayanan et al. 2016; Albarghouthi et al. 2017; Dechter et al. 2018; Wang, Hoffmann, and Reps 2018; Obermeyer et al. 2019; Holtzen, Van den Broeck, and Millstein 2020). Our work differs in two respects. First, our massively parallel methods are not restricted to discrete graphical models: they can operate with arbitrary continuous latent variables, and with graphs that mix continuous and discrete latent variables. Second, this prior work involves complex implementations that, in one way or another, “proceed by recording an adjoint compute graph alongside the forward computation and then traversing the adjoint graph backwards starting from the final result of the forward computation” (Obermeyer et al. 2019). The forward computation is reasonably straightforward: it is just a large tensor product that can be computed efficiently using pre-existing libraries such as opt-einsum, and it yields (an estimate of) the marginal likelihood. The backward traversal, however, is much more complex, if only because a separate traversal must be implemented for each operation of interest (computing posterior expectations, marginals and samples). These traversals must also handle every special case correctly, including optimized implementations of plates and timeseries. Importantly, optimizing the forward computation is usually quite straightforward, whereas implementing an optimized backward traversal is far more complex. For instance, the forward computation for a timeseries involves a product of T matrices arranged in a chain.
Naively computing this product on GPUs is very slow, as it requires T sequential matrix multiplications. However, this forward computation can be optimized dramatically, reducing O(T) tensor operations to O(log T) by multiplying all adjacent pairs of matrices in a single batched matrix multiplication and repeating on the result. This optimization is straightforward in the forward computation, but applying it as part of the backward computation is far more complex (see Corenflos, Chopin, and Särkkä 2022 for details). This complexity (along with similar complexity for other important optimizations such as plates) was prohibitive for academic teams implementing, e.g., new probabilistic programming languages. Our key contribution is thus a much simpler approach that directly computes posterior expectations, marginals and samples by differentiating through the forward computation, without having to hand-write and hand-optimize backward traversals.
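To make the contrast concrete, here is a minimal sketch in JAX (a framework we chose for illustration; the text does not name one) for the simplest possible case: a fully discrete chain with K states per time step, unary log-potentials `log_phi` of shape (T, K) and shared pairwise log-potentials `log_psi` of shape (K, K), where the forward sum is exact rather than an importance-sampling estimate. The hypothetical helper `chain_product` performs the O(log T) pairwise batched multiplication, and `log_marginal_likelihood` uses it to compute log Z. Because the derivative of log Z with respect to `log_phi[t, k]` equals the posterior marginal p(x_t = k), a single `jax.grad` call yields every marginal with no hand-written backward traversal. All function names here are our own; this illustrates the identity under those assumptions and is not the authors' implementation, which also covers continuous latents via importance sampling.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def chain_product(mats):
    """Return mats[0] @ mats[1] @ ... @ mats[n-1] using O(log n) batched matmuls."""
    k = mats.shape[-1]
    if mats.shape[0] == 0:
        return jnp.eye(k, dtype=mats.dtype)  # empty product
    while mats.shape[0] > 1:
        if mats.shape[0] % 2 == 1:
            # Append an identity so every matrix has a right-hand partner.
            mats = jnp.concatenate([mats, jnp.eye(k, dtype=mats.dtype)[None]], axis=0)
        # One batched matmul multiplies every adjacent pair (0,1), (2,3), ... at once,
        # halving the stack while preserving left-to-right order.
        mats = jnp.matmul(mats[0::2], mats[1::2])
    return mats[0]


def log_marginal_likelihood(log_phi, log_psi):
    """log Z = log sum_x prod_t phi_t(x_t) prod_{t>0} psi(x_{t-1}, x_t)."""
    phi = jnp.exp(log_phi)  # (T, K) unary potentials
    psi = jnp.exp(log_psi)  # (K, K) pairwise potentials, shared over time
    # mats[t-1, i, j] = psi[i, j] * phi[t, j] for t = 1, ..., T-1
    mats = psi[None, :, :] * phi[1:, None, :]
    z = phi[0] @ chain_product(mats) @ jnp.ones(phi.shape[1], dtype=phi.dtype)
    return jnp.log(z)


# Posterior marginals p(x_t = k) = d log Z / d log_phi[t, k]; autodiff supplies
# the backward pass through the tree reduction, so none is written by hand.
posterior_marginals = jax.grad(log_marginal_likelihood, argnums=0)


def brute_force_marginals(log_phi, log_psi):
    """Exact marginals by enumerating all K**T joint states (tiny models only)."""
    lp = np.asarray(log_phi, dtype=np.float64)
    lq = np.asarray(log_psi, dtype=np.float64)
    T, K = lp.shape
    totals = np.zeros((T, K))
    for x in itertools.product(range(K), repeat=T):
        log_w = sum(lp[t, x[t]] for t in range(T))
        log_w += sum(lq[x[t - 1], x[t]] for t in range(1, T))
        # Add this configuration's unnormalized weight to entry (t, x_t) for each t.
        totals[np.arange(T), list(x)] += np.exp(log_w)
    return totals / totals[0].sum()


# Usage on a small random chain (T = 5 steps, K = 3 states).
key_phi, key_psi = jax.random.split(jax.random.PRNGKey(0))
log_phi = jax.random.normal(key_phi, (5, 3))
log_psi = jax.random.normal(key_psi, (3, 3))
marginals = posterior_marginals(log_phi, log_psi)  # shape (5, 3); each row sums to 1
```

Because this toy version computes Z in probability space, long chains would underflow or overflow; a practical version would rescale or work in log space. The design point is that only the forward reduction is optimized by hand, and the gradient through the tree-structured product comes for free.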

There is work on fitting importance weighted autoencoders (IWAE; Burda, Grosse, and Salakhutdinov 2015) and reweighted wake-sleep (RWS; Bornschein and Bengio 2014; Le et al. 2020) in the massively parallel setting (Aitchison 2019; Geffner and Domke 2022; Heap and Laurence 2023) for general probabilistic models. However, this work only provides methods for performing massively parallel updates to approximate posteriors (e.g. by optimizing a massively parallel ELBO). It does not provide a method for individually reweighting the samples to obtain accurate posterior expectations, marginals and samples. Instead, it simply takes the learned approximate posterior as an estimate of the true posterior, and does not attempt to correct for the inevitable biases.
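The per-sample reweighting described above is, in its simplest single-latent form, self-normalized importance sampling: treat the approximate posterior as a proposal q, weight each sample by p(z, y) / q(z), and normalize the weights so they sum to one. The following minimal NumPy sketch uses a hypothetical conjugate toy model (prior N(0, 1), likelihood y_i ~ N(mu, 1), proposal N(0, 2²), made-up observations `y`), chosen so the weighted estimates can be checked against the closed-form posterior. It is a generic illustration, not the paper's massively parallel estimator, and the helper names are our own.

```python
import numpy as np


def log_normal_pdf(x, mean, std):
    """Log density of N(mean, std**2) evaluated at x (broadcasts over arrays)."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)


def self_normalized_weights(log_w):
    """Normalize log importance weights stably by subtracting the max before exp."""
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()


rng = np.random.default_rng(0)
y = np.array([0.8, 1.3, 0.4, 1.1])  # hypothetical observations
n_samples = 100_000

# Proposal q(mu) = N(0, 2^2), standing in for a learned approximate posterior
# that is deliberately broader than the true posterior.
mu = rng.normal(0.0, 2.0, size=n_samples)

# log p(mu, y) with prior N(0, 1) and likelihood y_i ~ N(mu, 1), minus log q(mu).
log_joint = log_normal_pdf(mu, 0.0, 1.0) + log_normal_pdf(
    y[:, None], mu[None, :], 1.0
).sum(axis=0)
log_w = log_joint - log_normal_pdf(mu, 0.0, 2.0)
w = self_normalized_weights(log_w)

# Reweighted posterior moments and the effective sample size 1 / sum(w^2).
posterior_mean_est = np.sum(w * mu)
posterior_var_est = np.sum(w * mu**2) - posterior_mean_est**2
effective_sample_size = 1.0 / np.sum(w**2)

# Conjugate closed form for comparison: N(sum(y) / (n + 1), 1 / (n + 1)).
exact_mean = y.sum() / (len(y) + 1)
exact_var = 1.0 / (len(y) + 1)
```

Using the unweighted proposal samples directly would estimate the prior-like N(0, 2²) moments; the weights are what correct the proposal's bias. The effective sample size shows how much of the sample survives reweighting, and it drops sharply when the proposal is a poor match for the posterior.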

Critically, our key contribution is not the massively parallel importance sampling method itself, which, we acknowledge, bears similarities to particle filtering/SMC methods (Gordon, Salmond, and Smith 1993; Doucet, Johansen et al. 2009; Andrieu, Doucet, and Holenstein 2010; Maddison et al. 2017; Le et al. 2017; Lindsten et al. 2017; Naesseth et al. 2018; Kuntz, Crucinio, and Johansen 2023; Lai, Domke, and Sheldon 2022; Crucinio and Johansen 2023) generalised to arbitrary graphical models and with the resampling step eliminated. Instead, our key contribution is the simple method for computing posterior expectations, marginals and samples without implementing complex backward traversals, which has not appeared in past work.