Efficient Neural Network Approaches: Partially Convex Potential Maps (PCP-Map) for Conditional OT

by Bayesian Inference, April 15th, 2024

Too Long; Didn't Read

This paper presents two neural network approaches that approximate the solutions of static and dynamic conditional optimal transport problems, respectively.

This paper is available on arXiv under CC 4.0 license.


(1) Zheyu Oliver Wang, Department of Aeronautics and Astronautics, Massachusetts Institute of Technology, Cambridge, MA;

(2) Ricardo Baptista, Computing + Mathematical Sciences, California Institute of Technology, Pasadena, CA;

(3) Youssef Marzouk, Department of Aeronautics and Astronautics, Massachusetts Institute of Technology, Cambridge, MA;

(4) Lars Ruthotto, Department of Mathematics, Emory University, Atlanta, GA;

(5) Deepanshu Verma, Department of Mathematics, Emory University, Atlanta, GA.

3. Partially Convex Potential Maps (PCP-Map) for Conditional OT.

Our first approach, PCP-Map, approximately solves the static COT problem through maximum likelihood training over a set of partially monotone maps. Motivated by the structure of transport maps that are optimal with respect to the quadratic cost function, we parameterize these maps as the gradients of scalar-valued neural networks that are strictly convex in x for any choice of the weights.

Training problem. Given samples from the joint distribution, our algorithm seeks a conditional generator $g$ by solving the maximum likelihood problem

$$\min_{g \in \mathcal{M}} \; J_{\mathrm{NLL}}[g^{-1}]. \tag{3.1}$$

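The objective $J_{\mathrm{NLL}}$ is the expected negative log-likelihood functional

$$J_{\mathrm{NLL}}[g^{-1}] = \mathbb{E}_{\pi(x,y)}\Big[\tfrac{1}{2}\big\|g^{-1}(x,y)\big\|^2 - \log\det \nabla_x g^{-1}(x,y)\Big], \tag{3.2}$$

which agrees up to an additive constant with the negative logarithm of (2.1), and $\mathcal{M}$ is the set of maps that are monotonically increasing in their first argument, i.e.,

$$\mathcal{M} = \Big\{ g : \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}^n \;:\; \big(g(x,y) - g(x',y)\big)^\top (x - x') \ge 0 \ \text{ for all } x, x' \in \mathbb{R}^n,\ y \in \mathbb{R}^m \Big\}. \tag{3.3}$$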
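A Monte Carlo estimate of the objective (3.2) makes the training problem concrete. The Python sketch below is a hypothetical one-dimensional toy, not the PCP-Map implementation: it assumes a standard Gaussian reference $\rho_z$, a joint distribution with $y \sim \mathcal{N}(0,1)$ and $x \mid y \sim \mathcal{N}(2y, 0.5^2)$, and an affine inverse map $g^{-1}(x,y) = (x - ay)/s$, which is monotone in $x$ whenever $s > 0$. The names `inverse_map` and `j_nll` are illustrative only.

```python
import math
import random

def inverse_map(x, y, a, scale):
    """Toy affine inverse generator g^{-1}(x, y) = (x - a*y) / scale.
    It is monotonically increasing in x whenever scale > 0."""
    return (x - a * y) / scale

def log_det_jacobian(scale):
    """log |d g^{-1} / dx| for the affine toy map (constant in x and y)."""
    return -math.log(scale)

def j_nll(samples, a, scale):
    """Monte Carlo estimate of E[0.5 * g^{-1}(x, y)^2 - log det grad_x g^{-1}(x, y)],
    assuming a standard Gaussian reference density rho_z."""
    total = 0.0
    for x, y in samples:
        z = inverse_map(x, y, a, scale)
        total += 0.5 * z * z - log_det_jacobian(scale)
    return total / len(samples)

# Hypothetical joint distribution: y ~ N(0, 1), x | y ~ N(A * y, S^2).
A, S = 2.0, 0.5
rng = random.Random(0)
samples = []
for _ in range(20000):
    y = rng.gauss(0.0, 1.0)
    samples.append((A * y + S * rng.gauss(0.0, 1.0), y))

exact = j_nll(samples, A, S)         # map that pushes pi(x|y) to N(0, 1)
too_wide = j_nll(samples, A, 1.0)    # monotone, but with the wrong scale
print(exact, too_wide)
```

For this toy, the exact map attains $J_{\mathrm{NLL}} = \tfrac12 + \log s \approx -0.19$, while any other scale $c$ gives $\tfrac{s^2}{2c^2} + \log c$, which is larger; minimizing the estimate therefore recovers the map that pushes $\pi(x|y)$ to $\rho_z$.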

We note that $J_{\mathrm{NLL}}$ also agrees (up to an additive constant) with the Kullback-Leibler (KL) divergence $D_{\mathrm{KL}}\big(\pi(\cdot|y)\,\|\,g(\cdot,y)_\sharp\rho_z\big)$ between the conditional distribution and the pushforward of $\rho_z$ through the generator $g$, taken in expectation over the conditioning variable $y$.
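The KL statement follows from the change-of-variables formula. A short derivation, assuming a standard Gaussian reference $\rho_z = \mathcal{N}(0, I_n)$ and writing $p_g(x|y) = \rho_z\big(g^{-1}(x,y)\big)\det\nabla_x g^{-1}(x,y)$ for the density of $g(\cdot,y)_\sharp\rho_z$:

```latex
\begin{aligned}
\mathbb{E}_{\pi(y)}\Big[D_{\mathrm{KL}}\big(\pi(\cdot|y)\,\|\,g(\cdot,y)_\sharp\rho_z\big)\Big]
&= \mathbb{E}_{\pi(x,y)}\big[\log \pi(x|y) - \log p_g(x|y)\big] \\
&= \mathbb{E}_{\pi(x,y)}\Big[\tfrac12\big\|g^{-1}(x,y)\big\|^2 - \log\det\nabla_x g^{-1}(x,y)\Big]
   + \tfrac{n}{2}\log(2\pi) + \mathbb{E}_{\pi(x,y)}\big[\log \pi(x|y)\big] \\
&= J_{\mathrm{NLL}}[g^{-1}] + \text{const}.
\end{aligned}
```

The last two terms in the middle line do not depend on $g$, so minimizing $J_{\mathrm{NLL}}$ and minimizing the expected KL divergence have the same solutions.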

Since the learning problem in (3.2) only involves the inverse generator $g^{-1}$, we seek this map directly in the same space of monotone functions $\mathcal{M}$ in (3.3). This avoids inverting the generator during training; the drawback is that sampling from the target conditional requires inverting the learned map. In this work, we therefore learn the inverse conditional generator $g^{-1}(\cdot, y)$, which pushes forward the conditional distribution $\pi(x|y)$ to the reference distribution $\rho_z$ for each $y$.

Limiting the search to monotone maps is motivated by Brenier's celebrated theorem, which ensures that, among all maps that can be written as the gradient of a convex potential, there exists a unique one, $g^{-1}(\cdot, y)$, such that $g^{-1}(\cdot, y)_\sharp \pi(\cdot|y) = \rho_z$, or equivalently $g(\cdot, y)_\sharp \rho_z = \pi(\cdot|y)$; see [8] for the original result and [10] for conditional transport maps. Theorem 2.3 in [10] also shows that $g^{-1}$ is optimal in the sense that, among all maps that match the distributions, it minimizes the integrated $L^2$ transport cost

$$\mathbb{E}_{\pi(x,y)}\Big[\big\|x - g^{-1}(x,y)\big\|^2\Big]. \tag{3.4}$$
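In one dimension, Brenier's monotone map reduces to the increasing rearrangement, and for two equal-size empirical samples the monotone (sorted) matching minimizes the quadratic transport cost, a discrete analogue of the optimality statement above for a single fixed $y$ (it follows from the rearrangement inequality). The Python sketch below checks this by brute force over all matchings of six points; the toy samples and the helper names `transport_cost` and `monotone_matching` are hypothetical and only illustrate the optimality claim, not how PCP-Map is trained.

```python
import itertools
import random

def transport_cost(xs, zs, perm):
    """Quadratic cost of sending xs[i] to zs[perm[i]]."""
    return sum((x - zs[j]) ** 2 for x, j in zip(xs, perm))

def monotone_matching(xs, zs):
    """Send the k-th smallest x to the k-th smallest z (the 1-D monotone map)."""
    order_x = sorted(range(len(xs)), key=lambda i: xs[i])
    order_z = sorted(range(len(zs)), key=lambda j: zs[j])
    perm = [0] * len(xs)
    for i, j in zip(order_x, order_z):
        perm[i] = j
    return perm

rng = random.Random(1)
xs = [rng.gauss(3.0, 2.0) for _ in range(6)]   # toy samples of pi(x | y) for one fixed y
zs = [rng.gauss(0.0, 1.0) for _ in range(6)]   # toy samples of the reference rho_z

# Brute force over all 6! = 720 matchings versus the sorted (monotone) matching.
best = min(transport_cost(xs, zs, p) for p in itertools.permutations(range(6)))
mono = transport_cost(xs, zs, monotone_matching(xs, zs))
print(best, mono)
```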

Neural Network Representation. In this work, we leverage the structural form of the conditional Brenier map by using partially input convex neural networks (PICNNs) [2] to express the inverse generator directly. In particular, we parameterize $g^{-1}$ as the gradient, with respect to its first input, of a PICNN $\tilde G_\theta : \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ with weights $\theta$, i.e., $g^{-1}(x,y) = \nabla_x \tilde G_\theta(x,y)$. A PICNN is a feed-forward neural network that is designed to be convex in some of its inputs; see [2] for the original work that introduced this architecture and [20] for its use in generative modeling. To the best of our knowledge, whether PICNNs are universal approximators of partially input convex functions remains an open question (see also [23, Section IV]) that is beyond the scope of this paper; a possibly related result for fully input convex neural networks is given in [20, Appendix C].

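To ensure the monotonicity of $g^{-1}$, we construct $\tilde G_\theta$ to be strictly convex in $x$ as a linear combination of a PICNN and a positive definite quadratic term. That is,

$$\tilde G_\theta(x,y) = \psi(\gamma_1)\, w_K + \big(\sigma_{\mathrm{ReLU}}(\gamma_2) + \psi(\gamma_3)\big)\, \tfrac{1}{2}\|x\|^2, \tag{3.5}$$

where $\gamma_1$, $\gamma_2$, and $\gamma_3$ are scalar parameters that are re-parameterized via the softplus function $\psi(x) = \log(1 + \exp(x))$ and the ReLU function $\sigma_{\mathrm{ReLU}}(x) = \max\{0, x\}$ to ensure strict convexity of $\tilde G_\theta$. Here, $w_K$ is the output of a $K$-layer PICNN and is computed by forward propagation through the layers $k = 0, \ldots, K-1$, starting with the inputs $v_0 = y$ and $w_0 = x$:

$$\begin{aligned}
v_{k+1} &= \sigma^{(v)}\big(L^{(v)}_k v_k + b^{(v)}_k\big),\\
w_{k+1} &= \sigma^{(w)}\Big(L^{(w)}_{k,+}\big(w_k \odot [L^{(wv)}_k v_k + b^{(wv)}_k]_+\big) + L^{(x)}_k\big(x \odot (L^{(xv)}_k v_k + b^{(xv)}_k)\big) + L^{(vw)}_k v_k + b^{(w)}_k\Big),
\end{aligned} \tag{3.6}$$

where $\odot$ denotes the elementwise product, $[\cdot]_+ = \sigma_{\mathrm{ReLU}}(\cdot)$ acts elementwise, and the weights $L^{(w)}_{k,+}$ are constrained to be non-negative.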

Using properties of the composition of convex functions [18], it can be verified that the forward propagation in (3.6) defines a function that is convex in $x$, but not necessarily in $y$ (which is not needed), as long as the weights acting on $w_k$ are non-negative and $\sigma^{(w)}$ is convex and non-decreasing.
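The following pure-Python sketch assembles a tiny PICNN in the spirit of (3.6), together with a potential that scales the PICNN output by $\psi(\gamma_1)$ and adds a quadratic term weighted by $\sigma_{\mathrm{ReLU}}(\gamma_2) + \psi(\gamma_3)$, and then checks midpoint convexity in $x$ numerically for a fixed $y$. It is a minimal illustration under stated assumptions, not the authors' implementation: the exact layer form, the use of softplus for $\sigma^{(v)}$ and $\sigma^{(w)}$, the identity activation on the last layer, and all parameter names (`Lw`, `Lwv`, `Lx`, and so on) are assumptions, and random weights stand in for trained ones.

```python
import math
import random

def softplus(t):
    """Numerically stable log(1 + exp(t)); convex and non-decreasing."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def matvec(M, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]

def rand_mat(rng, rows, cols, nonneg=False):
    M = [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]
    # Weights that multiply w_k are projected onto the non-negative orthant.
    return [[max(0.0, a) for a in row] for row in M] if nonneg else M

def make_picnn(rng, n, m, width, ctx, depth):
    """Random weights for a hypothetical PICNN with K = depth layers."""
    layers, dv, dw = [], m, n            # v_0 = y has size m, w_0 = x has size n
    for k in range(depth):
        out = 1 if k == depth - 1 else width   # the last layer outputs the scalar w_K
        layers.append(dict(
            Lv=rand_mat(rng, ctx, dv), bv=[0.0] * ctx,
            Lwv=rand_mat(rng, dw, dv), bwv=[0.1] * dw,
            Lw=rand_mat(rng, out, dw, nonneg=True),
            Lxv=rand_mat(rng, n, dv), bxv=[0.1] * n,
            Lx=rand_mat(rng, out, n),
            Lvw=rand_mat(rng, out, dv), bw=[0.0] * out,
        ))
        dv, dw = ctx, out
    return layers

def picnn(layers, x, y):
    """Forward propagation through layers k = 0, ..., K-1; returns the scalar w_K."""
    v, w = list(y), list(x)
    for k, p in enumerate(layers):
        gate_w = [max(0.0, a + b) for a, b in zip(matvec(p["Lwv"], v), p["bwv"])]
        gate_x = [a + b for a, b in zip(matvec(p["Lxv"], v), p["bxv"])]
        pre = [a + b + c + d for a, b, c, d in zip(
            matvec(p["Lw"], [wi * gi for wi, gi in zip(w, gate_w)]),  # convex in x
            matvec(p["Lx"], [xi * gi for xi, gi in zip(x, gate_x)]),  # affine in x
            matvec(p["Lvw"], v),                                      # constant in x
            p["bw"])]
        w = pre if k == len(layers) - 1 else [softplus(t) for t in pre]
        v = [softplus(a + b) for a, b in zip(matvec(p["Lv"], v), p["bv"])]
    return w[0]

def potential(layers, gammas, x, y):
    """Assumed form: psi(g1) * w_K + (relu(g2) + psi(g3)) * 0.5 * ||x||^2."""
    g1, g2, g3 = gammas
    quad = 0.5 * sum(xi * xi for xi in x)
    return softplus(g1) * picnn(layers, x, y) + (max(0.0, g2) + softplus(g3)) * quad

rng = random.Random(0)
n, m = 2, 1
layers = make_picnn(rng, n, m, width=8, ctx=4, depth=3)
gammas = (0.0, 0.5, -1.0)
y = [0.7]
violations = 0
for _ in range(500):
    x1 = [rng.uniform(-3.0, 3.0) for _ in range(n)]
    x2 = [rng.uniform(-3.0, 3.0) for _ in range(n)]
    mid = [(a + b) / 2 for a, b in zip(x1, x2)]
    lhs = potential(layers, gammas, mid, y)
    rhs = 0.5 * (potential(layers, gammas, x1, y) + potential(layers, gammas, x2, y))
    if lhs > rhs + 1e-10:
        violations += 1
print("midpoint-convexity violations:", violations)
```

Because every weight acting on $w_k$ is non-negative, the gates $[\cdot]_+$ are non-negative, and softplus is convex and non-decreasing, each layer preserves convexity in $x$; the quadratic term then makes the potential strictly convex, so the violation count should be zero for any seed.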

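To compute the log-determinant of the Jacobian of $g^{-1}$, that is, of the Hessian $\nabla_x^2 \tilde G_\theta$, we use vectorized automatic differentiation to obtain the Hessian of $\tilde G_\theta$ with respect to its first input and then compute its eigenvalues; since $\tilde G_\theta$ is strictly convex in $x$, all eigenvalues are positive and the log-determinant is the sum of their logarithms. This is feasible when the dimension $n$ of the Hessian is moderate; in our experiments, it is less than one hundred. We use efficient implementations of these methods that parallelize the computations over all the samples in a batch.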
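The log-determinant computation can be sketched for $n = 2$. In the pure-Python example below, central finite differences stand in for the vectorized automatic differentiation described above, the eigenvalues of the symmetric $2 \times 2$ Hessian are computed in closed form, and the potential $\tfrac12\|x\|^2 + \log(e^{x_1} + e^{x_2})$ is a hypothetical strictly convex example whose Hessian at the origin has eigenvalues $1$ and $1.5$.

```python
import math

def _shifted(G, x, i, j, si, sj, h):
    """Evaluate G at x + si*h*e_i + sj*h*e_j."""
    z = list(x)
    z[i] += si * h
    z[j] += sj * h
    return G(z)

def hessian_fd(G, x, h=1e-4):
    """Central finite-difference Hessian; a stand-in for automatic differentiation."""
    n = len(x)
    return [[(_shifted(G, x, i, j, 1, 1, h) - _shifted(G, x, i, j, 1, -1, h)
              - _shifted(G, x, i, j, -1, 1, h) + _shifted(G, x, i, j, -1, -1, h))
             / (4 * h * h) for j in range(n)] for i in range(n)]

def eig_sym_2x2(H):
    """Closed-form eigenvalues of a symmetric 2x2 matrix (symmetrized first)."""
    a, b, c = H[0][0], 0.5 * (H[0][1] + H[1][0]), H[1][1]
    mean, rad = 0.5 * (a + c), math.hypot(0.5 * (a - c), b)
    return mean - rad, mean + rad

def log_det_hessian(G, x):
    """log det of the Hessian as the sum of log-eigenvalues.
    Strict convexity of G makes every eigenvalue positive."""
    return sum(math.log(lam) for lam in eig_sym_2x2(hessian_fd(G, x)))

def G_example(x):
    """Hypothetical strictly convex potential: 0.5*||x||^2 + log(exp(x1) + exp(x2))."""
    return 0.5 * (x[0] ** 2 + x[1] ** 2) + math.log(math.exp(x[0]) + math.exp(x[1]))

print(log_det_hessian(G_example, [0.0, 0.0]), math.log(1.5))
```

For larger $n$, a batched symmetric eigensolver (for example, NumPy's `numpy.linalg.eigvalsh`, which accepts stacked matrices) would replace the closed-form $2 \times 2$ formula.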

Our algorithm enforces the non-negativity constraint by projecting the parameters onto the non-negative orthant after each optimization step using ReLU. This removes the need for re-parameterization, for example, using the softplus function as in [2]. Another novelty in our PICNN is that we use trainable affine layer parameters $L^{(v)}_k$ and treat the context feature width $u$ as a hyperparameter, which increases the expressiveness of the conditioning variables that are pivotal to characterizing conditional distributions; existing works such as [20] set $L^{(v)}_{1:K-2} = I$.
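Projecting onto the non-negative orthant after each step is an instance of projected gradient descent. The sketch below uses plain gradient steps as a stand-in for the optimizer used in training, together with a hypothetical quadratic loss $\tfrac12\|w - t\|^2$ whose constrained minimizer is $\max\{t, 0\}$ elementwise, so the effect of the ReLU projection is easy to check.

```python
def project_nonneg(weights):
    """Project onto the non-negative orthant by applying ReLU elementwise."""
    return [max(0.0, w) for w in weights]

def projected_gradient_descent(grad, w0, lr=0.1, steps=200):
    """Gradient step followed by projection, repeated `steps` times."""
    w = project_nonneg(w0)
    for _ in range(steps):
        g = grad(w)
        w = [wi - lr * gi for wi, gi in zip(w, g)]   # unconstrained step
        w = project_nonneg(w)                        # restore feasibility
    return w

target = [1.5, -2.0, 0.3, -0.1]

def grad_loss(w):
    """Gradient of the hypothetical loss 0.5 * ||w - target||^2."""
    return [wi - ti for wi, ti in zip(w, target)]

w_star = projected_gradient_descent(grad_loss, [0.5, 0.5, 0.5, 0.5])
print([round(v, 6) for v in w_star])   # → [1.5, 0.0, 0.3, 0.0]
```

Unlike a softplus re-parameterization, the projection leaves entries that should be exactly zero at zero, and it keeps the parameters themselves (rather than a transformed version) as the optimization variables.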

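Sample generation. Due to our neural network parameterization, there is generally no closed-form relation between $g$ and $g^{-1} = \nabla_x \tilde G_\theta$. As in [20], we approximate the inverse of $\nabla_x \tilde G_\theta$ during sampling using the Legendre-Fenchel dual (convex conjugate) of $\tilde G_\theta$. That is, for a reference sample $z$, we solve the convex optimization problem

$$g(z, y) = \arg\min_{x \in \mathbb{R}^n} \; \tilde G_\theta(x, y) - z^\top x. \tag{3.7}$$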

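Due to the strict convexity of $\tilde G_\theta$ in its first argument, the minimizer is unique, and the first-order optimality condition gives

$$\nabla_x \tilde G_\theta\big(g(z,y), y\big) = z, \quad \text{i.e.,} \quad g^{-1}\big(g(z,y), y\big) = z. \tag{3.8}$$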
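Sampling therefore amounts to solving a smooth, strongly convex problem for each pair $(z, y)$. The sketch below uses plain gradient descent as a stand-in for whatever solver is used in practice, on a hypothetical potential $\tilde G(x,y) = \tfrac12\|x\|^2 + \psi(y)\log\sum_i e^{x_i}$ (convex in $x$ for every $y$ because $\psi > 0$), and then checks the first-order condition $\nabla_x \tilde G(x, y) = z$ at the returned point.

```python
import math

def softplus(t):
    """Numerically stable log(1 + exp(t)), written psi in the text."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def grad_G(x, y):
    """grad_x of the hypothetical potential 0.5*||x||^2 + psi(y) * logsumexp(x)."""
    mx = max(x)
    e = [math.exp(xi - mx) for xi in x]      # shifted for numerical stability
    s = sum(e)
    c = softplus(y)
    return [xi + c * ei / s for xi, ei in zip(x, e)]

def sample_g(z, y, steps=500):
    """Approximate g(z, y) = argmin_x G(x, y) - z^T x by gradient descent.

    The objective is 1-strongly convex and (1 + psi(y))-smooth, so the fixed
    step size 1 / (1 + psi(y)) converges linearly to the unique minimizer.
    """
    lr = 1.0 / (1.0 + softplus(y))
    x = list(z)                              # warm start at the reference sample
    for _ in range(steps):
        g = grad_G(x, y)
        x = [xi - lr * (gi - zi) for xi, gi, zi in zip(x, g, z)]
    return x

z, y = [0.4, -1.2], 0.3
x = sample_g(z, y)
# First-order optimality: grad_x G(g(z, y), y) should reproduce z.
residual = max(abs(gi - zi) for gi, zi in zip(grad_G(x, y), z))
print("residual:", residual)
```

The fixed step size relies on a known smoothness constant, which holds for this toy potential; for a trained network that constant is not available in closed form, so a line search or a quasi-Newton method would be a more robust choice.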

Hyperparameters. In our numerical experiments, we vary only three hyperparameters to adjust the complexity of the architecture. As described in Section 5, we randomly sample the depth $K$, the feature width $w$, and the context width $u$ from the values in Table 2.
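The random sampling of architectures can be sketched as a simple random search. The candidate values below are placeholders chosen for illustration only; they are not the values listed in Table 2, which should be substituted in practice, and the names `SEARCH_SPACE` and `sample_configs` are hypothetical.

```python
import random

# Placeholder candidate values for illustration only; substitute the values from Table 2.
SEARCH_SPACE = {
    "depth_K": [2, 3, 4],
    "feature_width_w": [32, 64, 128],
    "context_width_u": [16, 32, 64],
}

def sample_configs(space, num, seed=0):
    """Draw `num` architectures uniformly at random from the candidate lists."""
    rng = random.Random(seed)
    return [{name: rng.choice(values) for name, values in space.items()}
            for _ in range(num)]

configs = sample_configs(SEARCH_SPACE, num=5)
for cfg in configs:
    print(cfg)
```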