Unbalanced Sobolev Descent

Abstract

We introduce Unbalanced Sobolev Descent (USD), a particle descent algorithm for transporting a high dimensional source distribution to a target distribution that does not necessarily have the same mass. We define the Sobolev-Fisher discrepancy between distributions and show that it relates to advection-reaction transport equations and the Wasserstein-Fisher-Rao metric between distributions. USD transports particles along gradient flows of the witness function of the Sobolev-Fisher discrepancy (advection step) and reweighs the mass of particles with respect to this witness function (reaction step). The reaction step can be thought of as a birth-death process of the particles with rate of growth proportional to the witness function. When the Sobolev-Fisher witness function is estimated in a Reproducing Kernel Hilbert Space (RKHS), under mild assumptions we show that USD converges asymptotically (in the limit of infinite particles) to the target distribution in the Maximum Mean Discrepancy (MMD) sense. We then give two methods to estimate the Sobolev-Fisher witness with neural networks, resulting in two Neural USD algorithms. The first one implements the reaction step with mirror descent on the weights, while the second implements it through a birth-death process of particles. We show on synthetic examples that USD transports distributions with or without conservation of mass faster than previous particle descent algorithms, and finally demonstrate its use for molecular biology analyses where our method is naturally suited to match developmental stages of populations of differentiating cells based on their single-cell RNA sequencing profile. Code is available at http://github.com/ibm/usd.

1 Introduction

Particle flows such as Stein Variational Gradient Descent (Liu and Wang, 2016), Sobolev descent (Mroueh et al., 2019) and MMD flows (Arbel et al., 2019) allow the transport of a source distribution to a target distribution, following paths that progressively decrease a discrepancy between distributions (the Kernel Stein discrepancy, the Sobolev discrepancy, and the MMD, respectively). Particle flows can be seen through the lens of Optimal Transport as gradient flows in the Wasserstein geometry (Santambrogio, 2017), and they have recently been used to analyze the dynamics of gradient descent in over-parametrized neural networks (Chizat and Bach, 2018) and of Generative Adversarial Network (GAN) training (Mroueh et al., 2019).

Unbalanced Optimal Transport (Chizat et al., 2015, 2018a, 2018b; Séjourné et al., 2019) is a new twist on the classical Optimal Transport theory (Villani, 2008), where the total mass between source and target distributions need not be conserved. The Wasserstein Fisher-Rao (WFR) distance introduced in Chizat et al. (2018a) gives a dynamic formulation similar to the Benamou-Brenier dynamic form of the Wasserstein-2 distance (Benamou and Brenier, 2000), where the dynamics of the transport is governed by an advection term with a velocity field $v_t$ and a reaction term with a rate of growth $r_t$, corresponding to the construction and destruction of mass at that rate:

$$\mathrm{WFR}^2(\mu_0, \mu_1) = \inf_{(\mu_t, v_t, r_t)} \int_0^1 \int_{\mathcal{X}} \Big( \|v_t(x)\|^2 + \frac{1}{\lambda}\, r_t^2(x) \Big)\, d\mu_t(x)\, dt \quad \text{s.t.} \quad \partial_t \mu_t = -\nabla\cdot(\mu_t v_t) + r_t\, \mu_t. \qquad (1)$$

From a particle flow point of view, this advection-reaction in Unbalanced Optimal Transport corresponds to processes of birth and death, where particles are created or killed in the transport from source to target. Particle gradient descent using the WFR geometry has been used in the analysis of over-parameterized neural networks, implemented as birth-death processes in Rotskoff et al. (2019) and as conic descent in Chizat (2019). In the context of particle transport, Lu et al. (2019) showed that birth and death processes can accelerate Langevin diffusion. On the application side, Unbalanced Optimal Transport is a powerful tool in biological modeling. For instance, the trajectory of tumor growth has been modeled in the WFR framework by Chizat and Marino (2017). Schiebinger et al. (2019) and Yang and Uhler (2019) used Unbalanced Optimal Transport to find differentiation trajectories of cells during development.

The dynamic formulation of WFR is computationally challenging, as it requires solving PDEs. One can instead use the unbalanced Sinkhorn divergence and apply an Euler scheme to find trajectories between source and target, as done in Feydy et al. (2018), but this approach does not come with any convergence guarantees.

In this paper we take another approach, similar to that of Sobolev Descent (Mroueh et al., 2019). We introduce the Kernel Sobolev-Fisher discrepancy, which is related to WFR and has the advantage of admitting a closed-form solution. We present a particle descent algorithm for the unbalanced case, named Unbalanced Sobolev Descent (USD), that consists of two steps: an advection step that follows the gradient flow of the witness function of the Sobolev-Fisher discrepancy, and a reaction step that reweighs the particles according to this witness function. We show theoretically that USD converges in the Maximum Mean Discrepancy (MMD) sense and that the reaction step accelerates convergence, in the sense that it yields strictly steeper descent directions; we also give a variant where the witness function is efficiently estimated as a neural network. We then empirically demonstrate the effectiveness and acceleration of USD in synthetic experiments and image color transfer tasks, and finally use it to model the developmental trajectories of populations of cells from single-cell RNA sequencing data (Schiebinger et al., 2019).

2 Sobolev-Fisher Discrepancy

In this section we define the Sobolev-Fisher Discrepancy (SF) and show how it relates to advection-reaction PDEs. While this formulation remains computationally challenging, we show in Section 3 how to approximate it in an RKHS.

2.1 Advection-Reaction with no Conservation of Mass

Definition 1 (Sobolev-Fisher Discrepancy).

Let $p, q$ be two (not necessarily normalized) measures defined on $\mathcal{X} \subset \mathbb{R}^d$. For $\lambda \geq 0$, the Sobolev-Fisher Discrepancy is defined as follows:

$$\mathcal{SF}_\lambda(p, q) = \sup_{f \in W_\lambda(q)} \mathbb{E}_{x\sim p}[f(x)] - \mathbb{E}_{x\sim q}[f(x)], \quad W_\lambda(q) = \Big\{ f \,:\, \mathbb{E}_{x\sim q}\big[\|\nabla_x f(x)\|^2 + \lambda f^2(x)\big] \leq 1 \Big\}.$$

Note that the objective of SF is an Integral Probability Metric (IPM) objective, where the function space $W_\lambda(q)$ imposes a constraint on the weighted Sobolev norm of the witness function on the support of the distribution $q$. We refer to $q$ as the source distribution and $p$ as the target distribution. The following theorem relates the solution of the Sobolev-Fisher Discrepancy to an advection-reaction PDE:

Theorem 1 (Sobolev-Fisher Critic as Solution of an Advection-Reaction PDE).

Let $u_{p,q}$ be the solution of the advection-reaction PDE:

$$-\nabla\cdot\big(q(x)\,\nabla u(x)\big) + \lambda\, q(x)\, u(x) = p(x) - q(x).$$

Then $\mathcal{SF}_\lambda(p,q) = \sqrt{\mathbb{E}_{x\sim q}\big[\|\nabla u_{p,q}(x)\|^2 + \lambda\, u^2_{p,q}(x)\big]}$, with witness function $f^* = u_{p,q}/\mathcal{SF}_\lambda(p,q)$.

From Theorem 1 we see that the witness function of $\mathcal{SF}_\lambda(p,q)$ solves an advection-reaction PDE where the mass is transported from $q$ to $p$, via an advection term following the gradient flow of $f^*$, and a reaction term amounting to construction/destruction of mass, which we also refer to as a birth-death process, with a rate given by $\lambda f^*$. Intuitively, where the witness function $f^*(x) > 0$ we need to create mass, and destroy mass where $f^*(x) < 0$. This is similar to the notion of particle birth and death defined in Rotskoff et al. (2019) and Lu et al. (2019).

In Proposition 1 we give a convenient unconstrained equivalent form for $\mathcal{SF}_\lambda$:

Proposition 1 (Unconstrained Form of $\mathcal{SF}_\lambda$).

SF satisfies the expression: $\mathcal{SF}^2_\lambda(p,q) = \sup_{f} \mathcal{E}(f)$, with

$$\mathcal{E}(f) = 2\big(\mathbb{E}_{x\sim p}[f(x)] - \mathbb{E}_{x\sim q}[f(x)]\big) - \mathbb{E}_{x\sim q}\big[\|\nabla_x f(x)\|^2 + \lambda f^2(x)\big].$$
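To make the equivalence with Definition 1 explicit (a one-line argument under the notation reconstructed above, writing $\langle f, p-q\rangle$ for the IPM objective and $\|f\|_q^2 = \mathbb{E}_{x\sim q}[\|\nabla_x f(x)\|^2 + \lambda f^2(x)]$ for the constraint functional), decompose $f = t\,g$ with $t \geq 0$ and $\|g\|_q = 1$:

```latex
\sup_{f}\; 2\langle f, p-q\rangle - \|f\|_q^2
  = \sup_{\|g\|_q = 1}\; \sup_{t \ge 0}\; 2t\,\langle g, p-q\rangle - t^2
  = \sup_{\|g\|_q = 1}\; \langle g, p-q\rangle^2
  = \mathcal{SF}_\lambda^2(p,q),
```

where the inner maximization over $t$ is attained at $t = \langle g, p-q\rangle_+$.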

Theorem 2 gives a physical interpretation of $\mathcal{SF}_\lambda$ as finding the velocity field and growth rate with minimum sum of kinetic energy and birth-death rate while transporting $q$ to $p$ via advection-reaction:

Theorem 2 (Kinetic Energy & Birth-Death Rate Minimization).

Consider the following minimization:

$$\mathcal{W}_\lambda(p,q) = \inf_{v,\, r} \;\mathbb{E}_{x\sim q}\big[\|v(x)\|^2\big] + \frac{1}{\lambda}\,\mathbb{E}_{x\sim q}\big[r^2(x)\big] \quad \text{s.t.} \quad -\nabla\cdot(q\, v) + q\, r = p - q.$$

We then have that $\mathcal{W}_\lambda(p,q) = \mathcal{SF}^2_\lambda(p,q)$, and moreover the optimum is attained at $v^* = \nabla u_{p,q}$ and $r^* = \lambda\, u_{p,q}$, where $u_{p,q}$ is the solution of the PDE in Theorem 1.

Remarks.

a) When $\lambda = 0$ we obtain the Sobolev Discrepancy $\mathcal{S}(p,q)$, i.e. $\mathcal{SF}_0(p,q) = \mathcal{S}(p,q)$, which linearizes the Wasserstein-2 distance. b) Note that Theorem 2 corresponds to a Beckmann type of optimal transport (Peyré and Cuturi, 2017), where we transport $q$ to $p$ ($p$ and $q$ do not have the same total mass) via an advection-reaction with mass not conserved. It is easy to see that the optimal rate $r^*$ does not in general integrate to zero under $q$, so total mass is not conserved.

2.2 Advection-Reaction with Conservation of Mass

Define the Sobolev-Fisher Discrepancy with conservation of mass:

$$\mathcal{SF}^c_\lambda(p,q) = \sup_{f \in W^c_\lambda(q)} \mathbb{E}_{x\sim p}[f(x)] - \mathbb{E}_{x\sim q}[f(x)], \quad \text{where} \quad W^c_\lambda(q) = \Big\{ f \,:\, \mathbb{E}_{x\sim q}\big[\|\nabla_x f(x)\|^2\big] + \lambda\, \mathrm{Var}_{x\sim q}[f(x)] \leq 1 \Big\}.$$

The only difference between this expression and $\mathcal{SF}_\lambda$ in Proposition 1 is that the variance of the witness function is kept under control, instead of its second-order moment. Defining the centered witness $\bar{f}(x) = f(x) - \mathbb{E}_{x\sim q}[f(x)]$,

one can similarly show that $\mathcal{SF}^c_\lambda$ has the primal representation:

$$\big(\mathcal{SF}^c_\lambda(p,q)\big)^2 = \inf_{v,\, r} \;\mathbb{E}_{x\sim q}\big[\|v(x)\|^2\big] + \frac{1}{\lambda}\,\mathbb{E}_{x\sim q}\big[r^2(x)\big] \quad \text{s.t.} \quad -\nabla\cdot(q\, v) + q\, r = p - q, \quad \mathbb{E}_{x\sim q}[r] = 0.$$

Hence, we see that $\mathcal{SF}^c_\lambda$ is the minimum sum of kinetic energy and variance of the birth-death rate for transporting $q$ to $p$ following an advection-reaction PDE with conserved total mass. The conservation of mass comes from the fact that the optimal rate $r^* = \lambda \bar{f}^*$ satisfies:

$$\mathbb{E}_{x\sim q}[r^*(x)] = 0.$$
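Under the advection-reaction dynamics used throughout this section, the zero-mean constraint on the rate is exactly what freezes the total mass. A short check, assuming the PDE $\partial_t q_t = -\nabla\cdot(q_t v) + r\, q_t$ and sufficient decay at the boundary for the divergence term to integrate to zero:

```latex
\frac{d}{dt}\int_{\mathcal{X}} dq_t
  = \int_{\mathcal{X}} \big(-\nabla\cdot(q_t v) + r\, q_t\big)\, dx
  = \underbrace{-\int_{\mathcal{X}} \nabla\cdot(q_t v)\, dx}_{=\,0}
    + \int_{\mathcal{X}} r\, dq_t
  = \mathbb{E}_{x\sim q_t}[r(x)] = 0 .
```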

3 Kernel Sobolev-Fisher Discrepancy

In this section we turn to the estimation of the SF discrepancy by restricting the witness function to a Reproducing Kernel Hilbert Space (RKHS), which results in a closed-form solution.

3.1 Estimation in Finite Dimensional RKHS

Consider the finite-dimensional RKHS $\mathcal{H}_m$ corresponding to an $m$-dimensional feature map $\Phi: \mathcal{X} \to \mathbb{R}^m$:

$$\mathcal{H}_m = \big\{ f \,:\, f(x) = \langle u, \Phi(x)\rangle,\; u \in \mathbb{R}^m \big\}.$$

Define the kernel mean embeddings $\mu(p) = \mathbb{E}_{x\sim p}[\Phi(x)]$ and $\mu(q) = \mathbb{E}_{x\sim q}[\Phi(x)]$. Let $C(q) = \mathbb{E}_{x\sim q}[\Phi(x)\Phi(x)^\top]$ be the covariance matrix and $D(q) = \mathbb{E}_{x\sim q}[J\Phi(x)\, J\Phi(x)^\top]$ be the Gramian of the Jacobian, where $[J\Phi(x)]_{j,a} = \partial \Phi_j(x)/\partial x_a$, $j = 1,\dots,m$, $a = 1,\dots,d$.

Definition 2 (Regularized Kernel Sobolev-Fisher Discrepancy (KSFD)).

Let $\lambda \geq 0$ and $\varepsilon > 0$, and let $\Delta = \mu(p) - \mu(q)$ and $\beta \in \{0, 1\}$; define $D_{\lambda,\beta}(q) = D(q) + \lambda\big(C(q) - \beta\, \mu(q)\mu(q)^\top\big)$. The Regularized Kernel Sobolev-Fisher Discrepancy is defined as:

$$\mathcal{KSF}_{\lambda,\varepsilon,\beta}(p,q) = \sup_{u \in \mathbb{R}^m} \Big\{ \langle u, \Delta\rangle \,:\, \big\langle u, (D_{\lambda,\beta}(q) + \varepsilon I)\, u\big\rangle \leq 1 \Big\}.$$

When $\beta = 0$ this corresponds to the unbalanced case, i.e. birth-death with no conservation of total mass, while for $\beta = 1$ we have birth-death with conservation of total mass.

Proposition 2 (Estimation in RKHS).

The Kernel Sobolev-Fisher Discrepancy is given by:

$$\mathcal{KSF}^2_{\lambda,\varepsilon,\beta}(p,q) = \big\langle \Delta,\, (D_{\lambda,\beta}(q) + \varepsilon I)^{-1} \Delta \big\rangle,$$

where the critic is $f^*(x) = \langle u^*, \Phi(x)\rangle$ (up to normalization by $\mathcal{KSF}_{\lambda,\varepsilon,\beta}$) with

$$u^* = (D_{\lambda,\beta}(q) + \varepsilon I)^{-1}\Delta, \quad \text{so that} \quad \mathcal{KSF}^2_{\lambda,\varepsilon,\beta}(p,q) = \langle \Delta, u^*\rangle.$$

Remarks.

a) For the unbalanced case $\beta = 0$, we refer to $\mathcal{KSF}_{\lambda,\varepsilon,0}$ as $\mathcal{KSF}^u$. For the case of mass conservation $\beta = 1$, we refer to $\mathcal{KSF}_{\lambda,\varepsilon,1}$ as $\mathcal{KSF}^c$. Note that $\mathcal{KSF}^u(p,q) \leq \mathcal{KSF}^c(p,q)$, since $D_{\lambda,1}(q) \preceq D_{\lambda,0}(q)$. b) A similar kernelized discrepancy was introduced in Arbel et al. (2018), but not as an approximation of the Sobolev-Fisher discrepancy, nor in the context of unbalanced distributions and advection-reaction. c) For $\lambda = 0$ we obtain the Kernel Sobolev Discrepancy (KSD) of Mroueh et al. (2019).

3.2 Kernel SF for Discrete Measures

Consider discrete measures $\hat{q} = \sum_{i=1}^N a_i\, \delta_{x_i}$ and $\hat{p} = \sum_{j=1}^M b_j\, \delta_{y_j}$ (with no conservation of mass we can have $\sum_i a_i \neq \sum_j b_j$). An estimate of the Sobolev-Fisher critic is given by $\hat{f}(x) = \langle \hat{u}, \Phi(x)\rangle$ with $\hat{u} = (\hat{D}_{\lambda,\beta}(\hat{q}) + \varepsilon I)^{-1}\big(\mu(\hat{p}) - \mu(\hat{q})\big)$, where the empirical Kernel Mean Embeddings are $\mu(\hat{q}) = \sum_i a_i\, \Phi(x_i)$ and $\mu(\hat{p}) = \sum_j b_j\, \Phi(y_j)$. The empirical operator embeddings are given by $\hat{D}(\hat{q}) = \sum_i a_i\, J\Phi(x_i) J\Phi(x_i)^\top$ and $\hat{C}(\hat{q}) = \sum_i a_i\, \Phi(x_i)\Phi(x_i)^\top$.
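For concreteness, the following NumPy sketch implements this estimator with a random Fourier feature map for a Gaussian kernel. It is a minimal illustration under the notation above, not the reference implementation from http://github.com/ibm/usd; the helper names (`rff_features`, `rff_jacobians`, `sf_critic`) and all hyperparameter values are assumptions made here.

```python
import numpy as np

def rff_features(X, W, c):
    """Random Fourier feature map Phi(x) = sqrt(2/m) * cos(W x + c)."""
    m = W.shape[0]
    return np.sqrt(2.0 / m) * np.cos(X @ W.T + c)           # (n, m)

def rff_jacobians(X, W, c):
    """Jacobians J Phi(x): one (m, d) matrix of dPhi_j/dx_a per sample."""
    m = W.shape[0]
    S = -np.sqrt(2.0 / m) * np.sin(X @ W.T + c)             # (n, m)
    return S[:, :, None] * W[None, :, :]                    # (n, m, d)

def sf_critic(Xq, a, Yp, b, W, c, lam=1.0, eps=1e-3, beta=0):
    """Empirical SF critic coefficients u_hat from weighted samples (Xq, a), (Yp, b)."""
    Pq, Pp = rff_features(Xq, W, c), rff_features(Yp, W, c)
    mu_q = Pq.T @ a                                         # mean embedding of q_hat
    mu_p = Pp.T @ b                                         # mean embedding of p_hat
    J = rff_jacobians(Xq, W, c)
    D = np.einsum('i,ijd,ikd->jk', a, J, J)                 # Gramian of the Jacobian
    C = Pq.T @ (a[:, None] * Pq)                            # second-moment matrix
    if beta == 1:
        C = C - np.outer(mu_q, mu_q)                        # variance instead of 2nd moment
    A = D + lam * C + eps * np.eye(W.shape[0])
    return np.linalg.solve(A, mu_p - mu_q)                  # u_hat

# usage: f(x) = <u, Phi(x)>, grad f(x) = JPhi(x)^T u
d, m = 2, 128
rng = np.random.default_rng(0)
W, c = rng.normal(size=(m, d)), rng.uniform(0, 2 * np.pi, size=m)
Xq, Yp = rng.normal(size=(200, d)), rng.normal(loc=2.0, size=(300, d))
a, b = np.full(200, 1 / 200), np.full(300, 1 / 300)
u = sf_critic(Xq, a, Yp, b, W, c)
```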

4 Unbalanced Continuous Kernel Sobolev Descent

Given the Kernel Sobolev-Fisher Discrepancy defined in the previous sections and its relation to advection-reaction transport, in this section we construct a Markov process that transports particles drawn from a source distribution to a target distribution. Note that we assume neither that the densities are normalized nor that they have the same total mass.

4.1 Constructing the Continuous Markov Process

Given $N$ weighted particles $x_i(0)$ drawn from the source distribution $q_0$, i.e. $\hat{q}_0 = \sum_{i=1}^N a_i(0)\, \delta_{x_i(0)}$ with $\sum_{i=1}^N a_i(0) = 1$. Recall that the target distribution is given by $\hat{p} = \sum_{j=1}^M b_j\, \delta_{y_j}$. We define the following Markov process that we name Unbalanced Kernel Sobolev Descent:

$$\frac{dx_i(t)}{dt} = \nabla f^*_{\hat{p}, q_t}(x_i(t)) \quad \text{(advection)}, \qquad \frac{da_i(t)}{dt} = \lambda\, a_i(t)\, f^*_{\hat{p}, q_t}(x_i(t)) \quad \text{(reaction)}, \qquad (2)$$

where $f^*_{\hat{p}, q_t}$ is the critic of the Kernel Sobolev-Fisher discrepancy between the target $\hat{p}$ and the current particle distribution $q_t = \sum_{i=1}^N a_i(t)\, \delta_{x_i(t)}$, whose expression and gradients are given in Proposition 2. We see that USD consists of two steps: an advection step that updates the particle positions following the gradient flow of the Sobolev-Fisher critic, and a reaction step that updates the weights of the particles with a growth rate proportional to that critic. The reaction step consists in mass construction or destruction, depending on the sign and magnitude of the witness function. This can be seen as a birth-death process on the particles, where the survival probability of a particle is proportional to the critic evaluated at that particle.

4.2 Generator Expression and PDE in the limit of $N \to \infty$

Proposition 3 gives the evolution equation of a functional of the intermediate distributions produced by the descent, in the limit of infinite particles $N \to \infty$:

Proposition 3.

Let $\mathcal{L}: \mathcal{M}(\mathcal{X}) \to \mathbb{R}$ be a functional on the space of measures. Let $q^N_t$ be the distribution produced by USD at time $t$, and let $q_t$ be its limit as $N \to \infty$. We have:

$$\frac{d\mathcal{L}(q_t)}{dt} = \mathbb{E}_{x\sim q_t}\Big[\Big\langle \nabla_x \frac{\delta \mathcal{L}}{\delta q_t}(x),\, \nabla f^*_{\hat{p}, q_t}(x)\Big\rangle\Big] + \lambda\, \mathbb{E}_{x\sim q_t}\Big[\frac{\delta \mathcal{L}}{\delta q_t}(x)\, f^*_{\hat{p}, q_t}(x)\Big],$$

where the functional derivative $\frac{\delta \mathcal{L}}{\delta q}$ is defined through first variation for a signed measure $\chi$: $\mathcal{L}(q + \epsilon\chi) = \mathcal{L}(q) + \epsilon \int \frac{\delta \mathcal{L}}{\delta q}\, d\chi + o(\epsilon)$.

In particular, the paths of USD in the limit of $N \to \infty$ satisfy the advection-reaction equation:

$$\frac{\partial q_t}{\partial t} = -\nabla\cdot\big(q_t\, \nabla f^*_{\hat{p}, q_t}\big) + \lambda\, q_t\, f^*_{\hat{p}, q_t}.$$

4.3 Unbalanced Sobolev Descent decreases the MMD.

The following theorem shows that USD, when the number of particles goes to infinity, decreases the MMD distance at each step, where:

$$\mathrm{MMD}^2(q, p) = \|\mu(q) - \mu(p)\|^2.$$

Theorem 3 (Unbalanced Sobolev Descent decreases the MMD).

Consider the paths $q_t$ produced by USD. In the limit of $N \to \infty$ particles we have

$$\frac{d}{dt}\mathrm{MMD}^2(q_t, p) = -2\,\Big(\mathrm{MMD}^2(q_t, p) - \varepsilon\, \mathcal{KSF}^2_{\lambda,\varepsilon,\beta}(p, q_t)\Big) \leq 0. \qquad (3)$$

In particular, in the regularized case with strict descent (i.e. $\mathrm{MMD}(q_t, p) \neq 0$ implies $\frac{d}{dt}\mathrm{MMD}^2(q_t, p) < 0$), USD converges in the MMD sense: $\lim_{t\to\infty}\mathrm{MMD}(q_t, p) = 0$. Similarly to Mroueh et al. (2019), strict descent is ensured if the kernel and the target distribution satisfy the condition:

$$\mu(p) - \mu(q) \notin \mathrm{Null}\big(D_{\lambda,\beta}(q)\big) \quad \text{whenever } \mu(p) \neq \mu(q).$$

USD Accelerates the Convergence. We now prove a lemma that can be used to show that Unbalanced Sobolev Descent has an acceleration advantage over Sobolev Descent (Mroueh et al., 2019).

Lemma 1.

In the regularized case with $\varepsilon > 0$ and $\lambda > 0$, the Kernel Sobolev-Fisher Discrepancy is strictly upper bounded by the Kernel Sobolev Discrepancy (i.e. the case $\lambda = 0$) of Mroueh et al. (2019):

$$\mathcal{KSF}_{\lambda,\varepsilon,\beta}(p,q) < \mathcal{KSD}_{\varepsilon}(p,q).$$
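With the closed form of Proposition 2, this ordering follows from the operator anti-monotonicity of the matrix inverse (a short check under the reconstructed notation; strictness holds whenever the second-moment or variance term does not vanish on $\Delta$):

```latex
D_{\lambda,\beta}(q) + \varepsilon I \;\succeq\; D(q) + \varepsilon I
\;\Longrightarrow\;
\big(D_{\lambda,\beta}(q) + \varepsilon I\big)^{-1} \;\preceq\; \big(D(q) + \varepsilon I\big)^{-1}
\;\Longrightarrow\;
\mathcal{KSF}^2_{\lambda,\varepsilon,\beta}(p,q)
= \big\langle \Delta, (D_{\lambda,\beta}(q) + \varepsilon I)^{-1}\Delta \big\rangle
\;\le\;
\big\langle \Delta, (D(q) + \varepsilon I)^{-1}\Delta \big\rangle
= \mathcal{KSD}^2_{\varepsilon}(p,q).
```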

From Lemma 1 and Eq. (3), we see that USD ($\lambda > 0$) results in a larger decrease in MMD than the SD of Mroueh et al. (2019) ($\lambda = 0$), i.e. a steeper descent. Hence, the advantages of USD over SD are twofold: 1) it allows unbalanced transport; 2) it accelerates convergence for both balanced and unbalanced transport.

USD with Universal Infinite-Dimensional Kernel. While we presented USD with a finite-dimensional kernel for ease of presentation, we show in Appendix D that all our results hold for an infinite-dimensional kernel. For a universal or characteristic kernel, convergence in MMD implies convergence in distribution (see Simon-Gabriel and Schölkopf, 2016, Theorem 12). Hence, using a universal kernel, USD guarantees the weak convergence of $q_t$ to $p$ as $t \to \infty$.

4.4 Understanding the effect of the Reaction Step: Whitened Principal Transport Directions

In Mroueh et al. (2019) it was shown that the gradient of the Sobolev Discrepancy critic can be written as a linear combination of principal transport directions, given by the eigenvectors of the Gramian of derivatives $D(q)$. Here we show that unbalanced descent leads to a similar interpretation in a whitened feature space, thanks to the regularizer. Let $\Delta = \mu(p) - \mu(q)$, let $S = (\lambda C(q) + \varepsilon I)^{1/2}$ (shown here for $\beta = 0$), and let $\tilde{\Phi}(x) = S^{-1}\Phi(x)$. It is easy to see that the critic of the SF can be written as: $f^*(x) \propto \langle (\tilde{D} + I)^{-1} S^{-1}\Delta,\, \tilde{\Phi}(x)\rangle$, where $\tilde{D} = S^{-1} D(q)\, S^{-1}$. Note that $\tilde{\Phi}$ is a whitened feature map and $\tilde{D}$ is the Gramian of its derivatives. Let $(\tilde{u}_k, \tau_k)$ be the eigenvectors and eigenvalues of $\tilde{D}$. We have: $(\tilde{D} + I)^{-1} = \sum_k \frac{1}{1+\tau_k}\, \tilde{u}_k \tilde{u}_k^\top$. Hence, we can write the gradient of the Sobolev-Fisher critic as $\nabla f^*(x) \propto \sum_k \gamma_k\, \nabla\langle \tilde{u}_k, \tilde{\Phi}(x)\rangle$, where $\gamma_k = \frac{\langle \tilde{u}_k, S^{-1}\Delta\rangle}{1 + \tau_k}$. This says that the mass is transported along a weighted combination of whitened principal transport directions $\nabla\langle \tilde{u}_k, \tilde{\Phi}(x)\rangle$. The regularizer $\lambda$ introduces a damping of the transport, as it acts as a spectral filter on the transport directions in the whitened space.
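The following NumPy sketch makes this spectral-filter interpretation concrete, using matrices $D$, $C$ and $\Delta$ as computed in the estimator of Section 3.2 (the function name and the eigendecomposition-based whitening are illustrative choices, not taken from the paper's code):

```python
import numpy as np

def whitened_transport_coefficients(D, C, delta, lam=1.0, eps=1e-3):
    """Spectral-filter view of the SF critic: whiten, eigendecompose, damp."""
    w, V = np.linalg.eigh(lam * C + eps * np.eye(len(D)))   # S^2 = lam*C + eps*I
    S_inv = V @ np.diag(1.0 / np.sqrt(w)) @ V.T             # inverse whitening matrix
    D_tilde = S_inv @ D @ S_inv                             # Gramian of whitened derivatives
    tau, U = np.linalg.eigh(D_tilde)                        # whitened principal directions
    gamma = (U.T @ (S_inv @ delta)) / (1.0 + tau)           # damped coefficients 1/(1+tau_k)
    return U, tau, gamma                                    # grad f ~ sum_k gamma_k grad<u_k, Phi_tilde>
```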

5 Discrete time Unbalanced Kernel and Neural Sobolev Descent

To obtain a practical algorithm, in this section we discretize the continuous USD dynamics given in Eq. (2). We also give an implementation that parameterizes the critic as a neural network.

Discrete Time Kernel USD. Recall that the source distribution is $\hat{q}_0 = \sum_{i=1}^N a_i(0)\, \delta_{x_i(0)}$, where $x_i(0) = x_i$ are the initial particles; assume for simplicity uniform initial weights $a_i(0) = 1/N$. The target distribution is $\hat{p} = \sum_{j=1}^M b_j\, \delta_{y_j}$. Let $\eta > 0$ be a step size and let $\hat{f}_\ell$ denote the critic of the Kernel Sobolev-Fisher discrepancy between $\hat{p}$ and $\hat{q}_\ell$. For $\ell = 0, \dots, L-1$ and $i = 1, \dots, N$, we discretize the advection step:

$$x_i(\ell+1) = x_i(\ell) + \eta\, \nabla_x \hat{f}_\ell(x_i(\ell)).$$

For $i = 1, \dots, N$, we similarly discretize the reaction step as:

$$\tilde{a}_i(\ell+1) = a_i(\ell)\, \big(1 + \eta\, \lambda\, \hat{f}_\ell(x_i(\ell))\big).$$

If $\beta = 0$ (total mass not conserved) we define the reweighing as follows: $a_i(\ell+1) = \tilde{a}_i(\ell+1)$; if $\beta = 1$ (mass conserved): $a_i(\ell+1) = \tilde{a}_i(\ell+1)\big/\sum_{j=1}^N \tilde{a}_j(\ell+1)$; and finally: $\hat{q}_{\ell+1} = \sum_{i=1}^N a_i(\ell+1)\, \delta_{x_i(\ell+1)}$.
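Putting the two discretized steps together, a minimal NumPy sketch of the kernel USD loop, reusing the helpers `sf_critic`, `rff_features` and `rff_jacobians` sketched in Section 3.2 (the step size and the clipping of weights to remain nonnegative are practical choices made here, not taken from the paper's reference code):

```python
def usd_step(Xq, a, Yp, b, W, c, lam=1.0, eps=1e-3, eta=0.1, beta=1):
    """One discrete USD iteration: advection of particles, then reaction on weights."""
    u = sf_critic(Xq, a, Yp, b, W, c, lam, eps, beta)
    J = rff_jacobians(Xq, W, c)                        # (N, m, d)
    grads = np.einsum('j,ijd->id', u, J)               # grad f(x_i) = JPhi(x_i)^T u
    f_vals = rff_features(Xq, W, c) @ u                # f(x_i)
    Xq = Xq + eta * grads                              # advection step
    a = a * (1.0 + eta * lam * f_vals)                 # reaction step
    a = np.clip(a, 0.0, None)                          # safeguard: keep weights nonnegative
    if beta == 1:
        a = a / a.sum()                                # mass conservation
    return Xq, a

# descend for L steps
for _ in range(500):
    Xq, a = usd_step(Xq, a, Yp, b, W, c)
```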

Neural Unbalanced Sobolev Descent. Motivated by the use of neural network critics in Sobolev Descent (Mroueh et al., 2019), we propose a neural variant of USD by parameterizing the critic of the Sobolev-Fisher Discrepancy as a neural network trained via gradient descent with the Augmented Lagrangian Method (ALM) on the loss function of SF given in Definition 1; a sketch of the ALM objective is given below. The re-weighting is defined as in the kernel case above. Neural USD with re-weighting is summarized in Algorithm 1 in Appendix B. Note that the re-weighting can also be implemented via a birth-death process as in Rotskoff et al. (2019). In this variant, particles are duplicated or killed with a probability driven by the growth rate given by the critic. We give the details of the implementation as a birth-death process in Algorithm 2 (Appendix B).
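For concreteness, a PyTorch sketch of the ALM objective for the critic update (a minimal illustration under the notation above, not the paper's reference code; the penalty parameter `rho`, the multiplier handling, and the assumption that `f` maps a batch of shape `(n, d)` to `(n, 1)` are ours):

```python
import torch

def alm_critic_loss(f, xq, aq, yp, bp, alpha, lam=1.0, rho=1e-4):
    """Augmented-Lagrangian objective for the SF critic (Definition 1):
    maximize E_p[f] - E_q[f] s.t. E_q[||grad f||^2 + lam * f^2] <= 1."""
    xq = xq.detach().requires_grad_(True)
    fq = f(xq).squeeze(-1)                                   # f on source particles
    fp = f(yp).squeeze(-1)                                   # f on target samples
    ipm = (bp * fp).sum() - (aq * fq).sum()                  # IPM objective
    grads = torch.autograd.grad(fq.sum(), xq, create_graph=True)[0]
    omega = (aq * ((grads ** 2).sum(dim=1) + lam * fq ** 2)).sum()
    constraint = omega - 1.0                                 # Sobolev-Fisher constraint
    lagrangian = ipm - alpha * constraint - 0.5 * rho * constraint ** 2
    return -lagrangian, constraint                           # minimize the negative
```

After each optimizer step on the critic parameters, the multiplier would be updated as `alpha += rho * constraint.item()`, following standard ALM practice.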

Computational and Sample Complexities. The computational complexity of Neural USD is given by that of updating the witness function and the particles by SGD with backpropagation, i.e. $O(b\, T + N\, G)$ per descent iteration, where $b$ is the mini-batch size, $T$ is the training time of the critic, and $G$ is the gradient computation time for the particles update; $G$ corresponds to a forward and a backward pass through the critic and its gradient. The sample complexity for estimating the Sobolev-Fisher critic scales like $O(1/\sqrt{N})$, similar to MMD (Gretton et al., 2012).

6 Relation to Previous Work

Table 1 in Appendix A summarizes the main differences between Sobolev Descent (Mroueh et al., 2019), which only implements advection, and USD, which implements advection-reaction. Our work is related to the conic particle descent that appeared in Chizat (2019) and Rotskoff et al. (2019). The main difference of our approach is that it is not based on the flow of a fixed functional; rather, we dynamically learn the flow that corresponds to the witness function of the Sobolev-Fisher discrepancy. The accelerated Langevin sampling of Lu et al. (2019) also uses similar principles in the transport of distributions, via Langevin diffusion and a reaction term implemented as a birth-death process. The main difference with our work is that in Langevin sampling the log-likelihood of the target distribution is required explicitly, while in USD we only need access to samples from the target distribution. USD relates to unbalanced optimal transport (Chizat et al., 2015, 2018a, 2018b; Séjourné et al., 2019) and offers computational flexibility when compared to Sinkhorn approaches (Chizat et al., 2018b; Séjourné et al., 2019), since it scales linearly in the number of points while Sinkhorn scales quadratically. Compared to WFR (Eq. (1)), USD finds the connecting path greedily, while WFR solves an optimal planning problem.

7 Applications

We experiment with USD on synthetic data, image color transfer, and prediction of developmental stages from scRNA-seq data. In all our experiments we report the MMD distance with a Gaussian kernel, computed using the random Fourier features (RF) approximation (Rahimi and Recht, 2007), with kernel bandwidth equal to the input dimension. We consider the conservation of mass case, i.e. $\beta = 1$.
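A minimal sketch of this evaluation metric, assuming random Fourier features for the Gaussian kernel; the number of random features and the exact bandwidth convention (here, kernel variance equal to the input dimension) are illustrative assumptions, not the paper's exact settings:

```python
import numpy as np

def mmd_rff(X, Y, a=None, b=None, n_features=1024, bandwidth=None, seed=0):
    """MMD between weighted samples X, Y under a Gaussian kernel, via random features."""
    d = X.shape[1]
    bandwidth = d if bandwidth is None else bandwidth       # bandwidth = input dim, as in the text
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=1.0 / np.sqrt(bandwidth), size=(n_features, d))
    c = rng.uniform(0, 2 * np.pi, size=n_features)
    phi = lambda Z: np.sqrt(2.0 / n_features) * np.cos(Z @ W.T + c)
    a = np.full(len(X), 1 / len(X)) if a is None else a     # default: uniform weights
    b = np.full(len(Y), 1 / len(Y)) if b is None else b
    return np.linalg.norm(phi(X).T @ a - phi(Y).T @ b)      # || mu(q) - mu(p) ||
```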

Synthetic Examples.

We test Neural USD (Algorithms 1 and 2) on two synthetic examples. In the first example (Figure 1), the source samples are drawn from a 2D standard Gaussian, while target samples are drawn from a Mixture of Gaussians (MOG). Samples from this MOG have uniform weights. In the second example (Figure 2), source samples are drawn from a 'cat'-shaped density, whereas the target samples are drawn uniformly from a 'heart'. Samples from the target have non-uniform weights following a horizontal gradient. In order to match such complex densities, USD exploits advection and reaction by following the critic gradients and by creating and destroying mass. We see in Figures 1 and 2 a faster mixing of USD, in both the implementation with weights (w) and the one with a birth-death (bd) process, compared to the Sobolev Descent algorithm of Mroueh et al. (2019).

(a) Neural USD paths in transporting a Gaussian to a MOG. We compare Sobolev descent (SD, Mroueh et al. (2019)) to both USD implementations: with birth-death process (bd: Algorithm 2) and weights (w: Algorithm 1). USD outperforms SD in capturing the modes of the MOG.
(b) MMD as a function of step along the descent from a Gaussian to a MOG. Both USD implementations converge faster to the target distribution, reaching lower MMD than Sobolev Descent, which relies on advection only.
Figure 1: Neural USD transport of a Gaussian to a MOG (target distribution is uniformly weighted).
(a) Neural USD transporting a 'cat'-distributed cloud to a 'heart'. The main difference with the example above is that the points of the target distribution have non-uniform weights describing a linear gradient, as seen from the color code in the figure. Similarly to the MOG case, USD outperforms SD and better captures the non-uniform density of the target.
(b) MMD as a function of step along the descent from the 'cat' to the 'heart' with gradient weights. Similarly to the uniform target case, USD accelerates the descent and outperforms SD.
Figure 2: Neural USD transport of a 'cat' to a non-uniform 'heart'. Samples from the target distribution have non-uniform weights $b_j$ following a linearly decaying gradient.

Image Color Transfer.

We test Neural USD on the image color transfer task. We choose target images that have sparse color distributions. This is a good test for unbalanced transport, since intuitively having birth and death of particles accelerates the convergence of the transport in this case. We compare USD to standard optimal transport algorithms. We follow the recipe of Ferradans et al. (2013) as implemented in the POT library (Flamary and Courty, 2017), where images are subsampled for computational feasibility and then interpolated for out-of-sample points. We compare USD to Earth Mover's Distance (EMD), Sinkhorn (Cuturi, 2013) and Unbalanced Sinkhorn (Chizat et al., 2018b) baselines. We see in Figure 3 that USD achieves a smaller MMD to the target color distribution. We show trajectories of USD in Fig. 7 in Appendix H.2.

Figure 3: Color Transfer with USD using (bd) Algorithm 2. Comparison to OT baselines (EMD, Sinkhorn and Unbalanced Sinkhorn). USD achieves lower MMD, and faithfully captures the sparse distribution of the target.

Developmental Trajectories of Single Cells.

Figure 4: Mean and standard deviation plots of normalized MMD and EMD for the intermediate stage prediction by USD and WOT (unbalanced OT) of Schiebinger et al. (2019) (means and standard deviations are computed over time intervals). While USD outperforms WOT in MMD, the reverse holds in EMD. See text for an explanation.

When the goal is not only to transport particles but also to find intermediate points along trajectories, USD becomes particularly interesting. This type of use case has recently received increased attention in developmental biology, thanks to single-cell RNA sequencing (scRNA-seq), a technique that records the expression profile of a whole population of cells at a given stage, but does so destructively. In order to trace the development of cells in-between such destructive measurements, Schiebinger et al. (2019) proposed to use unbalanced optimal transport (Chizat et al., 2018b). Denoting the measured populations $p_{t_1}$ (source) and $p_{t_2}$ (target), then, in order to predict the population at an intermediate time $t_1 < t < t_2$, Schiebinger et al. (2019) used a linear interpolation between matches between the source and target populations based on the coupling of unbalanced OT. This type of interpolation is a form of McCann interpolation (McCann, 1997). As an alternative, we propose to use the mid-point of the USD descent as an interpolant, i.e. the timestamp in the descent at which the MMD to the target has decreased to half of its initial value. We test this procedure on the dataset released by Schiebinger et al. (2019). For all time intervals $(t_1, t_2)$ in the dataset, we compute the intermediate stage prediction $\hat{p}_t$. We compare the quality of this interpolant with that obtained by the WOT algorithm of Schiebinger et al. (2019) in terms of MMD to the ground-truth intermediate population $p_t$, normalized by the MMD between initial and final populations, i.e. $\mathrm{MMD}(\hat{p}_t, p_t)/\mathrm{MMD}(p_{t_1}, p_{t_2})$. Fig. 4 gives the mean and standard deviation of the normalized MMD between intermediate stages predicted by USD and the ground truth. Note that means and standard deviations are computed across time intervals; individual MMDs can be found in Figure 8 in Appendix H. From Figure 4 we see that USD outperforms WOT in MMD, as expected since USD is designed to decrease the MMD distance. On the other hand, for fairness of the evaluation we also report the normalized EMD (Earth Mover's Distance, normalized similarly), for which WOT outperforms USD. This is not surprising since WOT relies on unbalanced OT, while USD instead provides guarantees in terms of MMD.

8 Conclusion

In this paper we introduced the Kernel Sobolev-Fisher discrepancy (KSFD) and showed how it relates to advection-reaction transport. Using the critic of the KSFD, we introduced Unbalanced Sobolev Descent (USD), which consists of an advection step that moves particles and a reaction step that re-weights their mass. The reaction step can be seen as a birth-death process which, as we show theoretically, speeds up the descent compared to previous particle descent algorithms. We showed the MMD convergence of Kernel USD and presented two neural implementations of USD, using weight updates and birth-death of particles, respectively. We empirically demonstrated on synthetic examples and in image color transfer that USD can be reliably used to transport distributions, and indeed does so with accelerated convergence, supporting our theoretical analysis. As a further demonstration of our algorithm, we showed that USD can be used to predict developmental trajectories of single cells based on their RNA expression profiles. This task is representative of a situation where distributions of different mass need to be compared and interpolated between, since the different scRNA-seq measurements are taken on cell populations of dissimilar size at different developmental stages. USD can naturally deal with this unbalanced setting. Finally, we compared USD to unbalanced OT algorithms, showing its viability as a data-driven, more scalable dynamic transport method.

Broader Impact Statement

Our work provides a practical particle descent algorithm that comes with a formal convergence proof and theoretically guaranteed acceleration over previous competing algorithms. Moreover, our algorithm can naturally handle situations where the objects of the descent are particles sampled from a source distribution descending towards a target distribution with different mass.

The type of applications that this enables ranges from theoretically principled modeling of biological growth processes (like tumor growth) and developmental processes (like the differentiation of cells in their gene expression space), to faster numerical simulation of advection-reaction systems.

Since our advance is mainly theoretical and algorithmic (besides the empirical demonstrations), its implications are necessarily tied to the uses for which it is deployed. Besides the applications that we mentioned, particle descent algorithms like ours have been proposed as a paradigm to characterize and study the dynamics of Generative Adversarial Network (GAN) training. As such, they could indirectly contribute to the risks associated with nefarious uses of GANs, such as deepfakes. On the other hand, by providing tools to analyze and better understand GANs, our theoretical results might serve as the basis for mitigating their abuse.

Supplementary Material: Unbalanced Sobolev Descent

Relation to Unbalanced Optimal Transport

We now relate our definition of the Sobolev-Fisher discrepancy to the following norm. For a signed measure $\chi$, define

$$\|\chi\|_{\dot{H}^{-1}_\lambda(\mu)} = \sup_{f} \Big\{ \int f\, d\chi \,:\, \int \big(\|\nabla f\|^2 + \lambda f^2\big)\, d\mu \leq 1 \Big\}.$$

It can be shown that $\mathcal{SF}_\lambda(p, q) = \|p - q\|_{\dot{H}^{-1}_\lambda(q)}$.

The dynamic formulation of the Wasserstein Fisher-Rao metric given in Equation (1) can therefore be compactly written as:

$$\mathrm{WFR}^2(\mu_0, \mu_1) = \inf_{(\mu_t)} \int_0^1 \|\partial_t \mu_t\|^2_{\dot{H}^{-1}_\lambda(\mu_t)}\, dt. \qquad (4)$$

From this connection to WFR through $\|\cdot\|_{\dot{H}^{-1}_\lambda}$, we see the link of the Sobolev-Fisher discrepancy to unbalanced optimal transport, since it linearizes the WFR for small perturbations.
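Concretely, the linearization can be spelled out as a first-order expansion (stated here under the reconstructed notation above, in analogy with the way the Sobolev discrepancy linearizes the Wasserstein-2 distance):

```latex
\mathrm{WFR}(\mu,\, \mu + \epsilon\chi)
  = \|\epsilon\chi\|_{\dot{H}^{-1}_\lambda(\mu)} + o(\epsilon)
  = \mathcal{SF}_\lambda(\mu + \epsilon\chi,\, \mu) + o(\epsilon).
```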

Appendix A Summary Table

| Method | Markov process | Particle weights | PDE (as $N \to \infty$) | Transport directions |
| --- | --- | --- | --- | --- |
| Sobolev Descent (Mroueh et al., 2019) | flow of the Sobolev critic (advection only); target $\hat{p}$, source $\hat{q}_0$ | N/A (uniform) | advection | principal transport directions |
| Unbalanced Sobolev Descent ($\beta = 0$) | flow of the Sobolev-Fisher critic with reweighing; target $\hat{p}$, source $\hat{q}_0$ | updated; mass not conserved | advection-reaction | whitened principal transport directions |
| Balanced Sobolev Descent ($\beta = 1$) | flow of the Sobolev-Fisher critic with normalized reweighing; target $\hat{p}$, source $\hat{q}_0$ | updated and normalized; mass conserved | advection-reaction | whitened principal transport directions |

Table 1: Summary table comparing Unbalanced Sobolev Descent to Sobolev Descent.

Appendix B Algorithms

  Inputs: Learning rate η, number of particles N, number of critic updates n_c, number of iterations L, λ, β, {y_j, b_j}_{j=1…M} drawn from the target distribution, {x_i}_{i=1…N} drawn from the source distribution, neural critic f_θ, θ parameters of the neural network
  Initialize a_i = 1/N for i = 1…N
  for ℓ = 1 to L do
     Critic Parameters Update
     (between particles updates, gradient descent on the critic is initialized from previous episodes)
     f_θ ← Critic Update(f_θ, target {y_j, b_j}, current source {x_i, a_i}, n_c) (Given in Alg. 3 in Appendix B)
     Particles and Weights Update
     for i = 1 to N do
        x_i ← x_i + η ∇_x f_θ(x_i) (the current f_θ is the critic between p̂ and q̂_ℓ; advection step)
        ã_i ← a_i (1 + η λ f_θ(x_i)) (reaction step)
     end for
     if β = 1 (mass conservation) then
        a_i ← ã_i / Σ_{j=1…N} ã_j for all i
     else if β = 0 (mass not conserved) then
        a_i ← ã_i for all i
     end if
  end for
  Output: {(x_i, a_i)}_{i=1…N}
Algorithm 1 w-Neural Unbalanced Sobolev Descent (weighted version – ALM Algorithm)
  Inputs: Same inputs as Algorithm 1
  Initialize a_i = 1/N for i = 1…N
  for ℓ = 1 to L do
     Critic Parameters Update
     (between particles updates, gradient descent on the critic is initialized from previous episodes)
     f_θ ← Critic Update(f_θ, target {y_j, b_j}, current source {x_i, 1/N}, n_c) (Given in Alg. 3 in App. B)
     Particles Update (birth-death)
     for i = 1 to N do
        x_i ← x_i + η ∇_x f_θ(x_i) (the current f_θ is the critic between p̂ and q̂_ℓ)
        r_i ← η λ f_θ(x_i)
        if r_i > 0 then
           Duplicate x_i with probability r_i
        else if r_i < 0 then
           Kill x_i with probability −r_i
        end if
     end for {Make population size N again}
     N′ ← number of particles at the end of the loop
     if N′ > N then
        Kill N′ − N randomly selected particles
     else if N′ < N then
        Duplicate N − N′ randomly selected particles
     end if
  end for
  Output: {x_i}_{i=1…N}
Algorithm 2 bd-Neural Unbalanced Sobolev Descent (Birth-Death – ALM Algorithm)
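For completeness, a minimal NumPy sketch of the birth-death step of Algorithm 2 above (the clipping of the rate to [-1, 1] so that it can be used as a probability is a practical safeguard added here, not taken from the paper):

```python
import numpy as np

def birth_death_step(X, f_vals, eta, lam, rng):
    """Duplicate/kill particles with probability proportional to the critic, then resize."""
    N = len(X)
    rate = np.clip(eta * lam * f_vals, -1.0, 1.0)          # birth (>0) / death (<0) rates
    u = rng.uniform(size=N)
    keep = ~((rate < 0) & (u < -rate))                     # kill with probability -rate
    dup = (rate > 0) & (u < rate)                          # duplicate with probability rate
    X = np.concatenate([X[keep], X[dup]])
    if len(X) > N:                                         # make population size N again
        X = X[rng.choice(len(X), size=N, replace=False)]
    elif len(X) < N:
        extra = X[rng.choice(len(X), size=N - len(X), replace=True)]
        X = np.concatenate([X, extra])
    return X
```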
  for t = 1 to n_c do
     Sample a mini-batch {(y_j, b_j)}_{j∈B_p} from the target and {(x_i, a_i)}_{i∈B_q} from the current source
     Ê ← Σ_{j∈B_p} b_j f_θ(y_j) − Σ_{i∈B_q} a_i f_θ(x_i) (IPM objective)
     Ω̂ ← Σ_{i∈B_q} a_i (‖∇_x f_θ(x_i)‖² + λ f_θ²(x_i)) (Sobolev-Fisher constraint)
     L̂ ← Ê − α (Ω̂ − 1) − (ρ/2)(Ω̂ − 1)² (Augmented Lagrangian)
     α ← α + ρ (Ω̂ − 1) (multiplier update)
     θ ← θ + γ ∇_θ L̂ {SGD rule on θ with learning rate γ}
  end for
  Output: f_θ
Algorithm 3 Critic Update(f_θ, target {y_j, b_j}, current source {x_i, a_i}, n_c)