A Fast Globally Linearly Convergent Algorithm for the Computation of Wasserstein Barycenters

Lei Yang¹, Jia Li², Defeng Sun³, Kim-Chuan Toh⁴

¹ Institute of Operations Research and Analytics, National University of Singapore, 10 Lower Kent Ridge Road, Singapore 119076. (orayl@nus.edu.sg).
² Department of Statistics, Pennsylvania State University, University Park, PA 16802, USA. (jiali@stat.psu.edu).
³ Department of Applied Mathematics, The Hong Kong Polytechnic University, Hung Hom, Kowloon, Hong Kong. (defeng.sun@polyu.edu.hk). The research of this author was supported in part by a start-up research grant from the Hong Kong Polytechnic University.
⁴ Department of Mathematics, and Institute of Operations Research and Analytics, National University of Singapore, 10 Lower Kent Ridge Road, Singapore 119076. (mattohkc@nus.edu.sg). The research of this author was supported in part by the Ministry of Education, Singapore, Academic Research Fund (Grant No. R-146-000-256-114).

Abstract

In this paper, we consider the problem of computing a Wasserstein barycenter for a set of discrete probability distributions with finite supports, which finds many applications in areas such as statistics, machine learning and image processing. When the support points of the barycenter are pre-specified, this problem can be modeled as a linear program (LP) whose size can be extremely large. To handle this large-scale LP, we derive its dual problem, which is conceivably more tractable and can be reformulated as a well-structured convex problem with 3 kinds of block variables and a coupling linear equality constraint. We then adapt a symmetric Gauss-Seidel based alternating direction method of multipliers (sGS-ADMM) to solve the resulting dual problem and analyze its global convergence as well as its global linear convergence rate. We also show how all the subproblems involved can be solved exactly and efficiently. This makes our method suitable for computing a Wasserstein barycenter on a large dataset. In addition, our sGS-ADMM can be used as a subroutine in an alternating minimization method to compute a barycenter when its support points are not pre-specified. Numerical results on synthetic datasets and image datasets demonstrate that our method is more efficient for solving large-scale problems, compared with two existing representative methods and the commercial software Gurobi.

Keywords:   Wasserstein barycenter; discrete probability distribution; semi-proximal ADMM; symmetric Gauss-Seidel.

1 Introduction

In this paper, we consider the problem of computing the mean of a set of discrete probability distributions under the Wasserstein distance (also known as the optimal transport distance or the earth mover’s distance). This mean, called the Wasserstein barycenter, is also a discrete probability distribution [1]. Recently, the Wasserstein barycenter has attracted much attention due to its promising performance in many application areas such as data analysis and statistics [5], machine learning [13, 23, 36, 37] and image processing [27]. For a set of discrete probability distributions with finite support points, a Wasserstein barycenter with its support points being pre-specified can be computed by solving a linear programming (LP) problem [2]. However, the problem size can be extremely large when the number of discrete distributions or the number of support points of each distribution is large. Thus, the classical LP methods such as the simplex method and the interior point method are no longer efficient or consume too much memory when solving this problem. This motivates the study of fast algorithms for the computation of Wasserstein barycenters; see, for example, [4, 6, 7, 11, 13, 33, 36, 37].

One representative approach is to introduce an entropy regularization in the LP and then apply some efficient first-order methods, e.g., the gradient descent method [13] and the iterative Bregman projection (IBP) method [4], to solve the regularized problem. These methods can be implemented efficiently and hence are suitable for large-scale datasets. However, they can only return an approximate solution of the LP and often encounter numerical issues when the regularization parameter becomes small. Another approach is to consider the LP as a constrained convex optimization problem with a separable structure and then apply some splitting methods to solve it. For example, the alternating direction method of multipliers (ADMM) was adapted in [36]. However, solving the quadratic programming subproblems involved is still highly expensive. Later, Ye et al. [37] developed a modified Bregman ADMM (BADMM) based on the original one [35] to solve the LP. In this method, all subproblems have closed-form solutions and hence can be solved efficiently. Promising numerical performance is also reported in [37]. However, this modified Bregman ADMM does not have a convergence guarantee so far.

In this paper, we also consider the LP as a constrained convex problem with multiple blocks of variables and develop an efficient method to solve its dual LP without introducing the entropy regularization to modify the objective function. Our method is actually a convergent 3-block ADMM that is designed based on recent progress in research on convergent multi-block ADMM-type methods for solving convex composite conic programming; see [10, 25]. It is well known that the classical ADMM was originally proposed to solve a convex problem that contains 2 blocks of variables and a coupling linear equality constraint [16, 17]. The 2-block ADMM can be naturally extended to a multi-block ADMM for solving a convex problem with more than 2 blocks of variables. However, it has been shown in [8] that the directly extended ADMM may not converge when applied to a convex problem with 3 or more blocks of variables. This has inspired many researchers to develop various convergent variants of the ADMM for convex problems with more than 2 blocks of variables; see, for example, [9, 10, 20, 24, 25, 32]. Among them, the Schur complement based convergent semi-proximal ADMM (sPADMM) was proposed by Li et al. [25] to solve a large class of linearly constrained convex problems with multiple blocks of variables, whose objective can be the sum of two proper closed convex functions and a finite number of convex quadratic or linear functions. This method essentially follows the original ADMM, but performs one more forward Gauss-Seidel sweep after updating the block of variables corresponding to the nonsmooth function in the objective. With this novel strategy, Li et al. [25] show that their method can be reformulated as a 2-block sPADMM with specially designed semi-proximal terms and its convergence is guaranteed from that of the 2-block sPADMM; see [15, Appendix B].
Later, this method was generalized to the inexact symmetric Gauss-Seidel based ADMM (sGS-ADMM) for more general convex problems [10, 26]. The numerical results reported in [10, 25, 26] also show that the sGS-ADMM always performs much better than the possibly non-convergent directly extended ADMM. In addition, as the sGS-ADMM is equivalent to a 2-block sPADMM with specially designed proximal terms, the linear convergence rate of the sGS-ADMM can also be derived based on the linear convergence rate of the 2-block sPADMM under some mild conditions; more details can be found in [19, Section 4.1].

Motivated by the above studies, in this paper, we adapt the sGS-ADMM to compute a Wasserstein barycenter by solving the dual problem of the original primal LP. The contributions of this paper are as follows:

  • We derive the dual problem of the original primal LP and characterize the properties of their optimal solutions; see Proposition 4.1. The resulting dual problem is our target problem, which is a linearly constrained convex problem containing 3 blocks of variables with a nice separable structure. We should emphasize again that we do not introduce any entropy regularization to modify the LP so as to make it computationally more tractable. This is in contrast to most existing works (e.g., [4, 6, 13, 33, 36, 37]) that primarily focus on (approximately) solving the original primal LP.

  • We apply the sGS-ADMM to solve the resulting dual problem and analyze its global convergence as well as its global linear convergence rate without any condition; see Theorems 4.1 and 4.2. We also show how all the subproblems in our method can be solved efficiently and that the subproblems at each step can be computed in parallel. This makes our sGS-ADMM highly suitable for computing Wasserstein barycenters on a large dataset.

  • We conduct rigorous numerical experiments on synthetic datasets and MNIST to evaluate the performance of our sGS-ADMM in comparison to existing state-of-the-art methods (IBP and BADMM) and the highly powerful commercial solver Gurobi. The computational results show that our sGS-ADMM performs much better than IBP and BADMM, and is also able to outperform Gurobi in solving large-scale LPs arising from Wasserstein barycenter problems.

The rest of this paper is organized as follows. In Section 2, we describe the basic problem of computing Wasserstein barycenters and derive its dual problem. In Section 3, we adapt the sGS-ADMM to solve the resulting dual problem and present the efficient implementations for each step that are crucial in making our method competitive. The convergence analysis of the sGS-ADMM is presented in Section 4. A simple extension to the free support case is discussed in Section 5. Finally, numerical results are presented in Section 6, with some concluding remarks given in Section 7.

Notation and Preliminaries.

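In this paper, we present scalars, vectors and matrices in lower case letters, bold lower case letters and upper case letters, respectively. We use $\mathbb{R}$, $\mathbb{R}^n$, $\mathbb{R}^n_+$ and $\mathbb{R}^{m\times n}$ to denote the set of real numbers, $n$-dimensional real vectors, $n$-dimensional real vectors with nonnegative entries and $m\times n$ real matrices, respectively. For a vector $\boldsymbol{x}$, $x_i$ denotes its $i$-th entry, $\|\boldsymbol{x}\|$ denotes its Euclidean norm and $\|\boldsymbol{x}\|_T:=\sqrt{\langle \boldsymbol{x}, T\boldsymbol{x}\rangle}$ denotes its weighted norm associated with the symmetric positive semidefinite matrix $T$. For a matrix $X$, $x_{ij}$ denotes its $(i,j)$-th entry, $X_{i:}$ denotes its $i$-th row, $X_{:j}$ denotes its $j$-th column, $\|X\|_F$ denotes its Frobenius norm and $\mathrm{vec}(X)$ denotes the vectorization of $X$. We also use $\boldsymbol{x}\ge 0$ and $X\ge 0$ to denote $x_i\ge 0$ for all $i$ and $x_{ij}\ge 0$ for all $(i,j)$. The identity matrix of size $n\times n$ is denoted by $I_n$. For any $X_1\in\mathbb{R}^{m\times n_1}$ and $X_2\in\mathbb{R}^{m\times n_2}$, $[X_1, X_2]\in\mathbb{R}^{m\times(n_1+n_2)}$ denotes the matrix obtained by horizontally concatenating $X_1$ and $X_2$. For any $Y_1\in\mathbb{R}^{m_1\times n}$ and $Y_2\in\mathbb{R}^{m_2\times n}$, $[Y_1; Y_2]\in\mathbb{R}^{(m_1+m_2)\times n}$ denotes the matrix obtained by vertically concatenating $Y_1$ and $Y_2$.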

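For an extended-real-valued function $f:\mathbb{R}^n\to[-\infty,\infty]$, we say that it is proper if $f(\boldsymbol{x})>-\infty$ for all $\boldsymbol{x}\in\mathbb{R}^n$ and its domain $\mathrm{dom}\,f:=\{\boldsymbol{x}\in\mathbb{R}^n : f(\boldsymbol{x})<\infty\}$ is nonempty. A proper function $f$ is said to be closed if it is lower semicontinuous. For a proper closed convex function $f:\mathbb{R}^n\to(-\infty,\infty]$, its subdifferential at $\boldsymbol{x}\in\mathrm{dom}\,f$ is $\partial f(\boldsymbol{x}):=\{\boldsymbol{d}\in\mathbb{R}^n : f(\boldsymbol{y})\ge f(\boldsymbol{x})+\langle \boldsymbol{d},\,\boldsymbol{y}-\boldsymbol{x}\rangle,\ \forall\,\boldsymbol{y}\in\mathbb{R}^n\}$ and its conjugate function $f^*:\mathbb{R}^n\to(-\infty,\infty]$ is defined by $f^*(\boldsymbol{y}):=\sup\{\langle \boldsymbol{y},\,\boldsymbol{x}\rangle-f(\boldsymbol{x}) : \boldsymbol{x}\in\mathbb{R}^n\}$. For any $\boldsymbol{x}$ and $\boldsymbol{y}$, it follows from [29, Theorem 23.5] that

$$
\boldsymbol{y}\in\partial f(\boldsymbol{x}) \iff \boldsymbol{x}\in\partial f^*(\boldsymbol{y}). \tag{1.1}
$$

For any $\nu>0$, the proximal mapping of $\nu f$ at $\boldsymbol{y}$ is defined by $\mathrm{Prox}_{\nu f}(\boldsymbol{y}):=\arg\min_{\boldsymbol{x}}\big\{f(\boldsymbol{x})+\frac{1}{2\nu}\|\boldsymbol{x}-\boldsymbol{y}\|^2\big\}$. For a closed convex set $\mathcal{X}\subseteq\mathbb{R}^n$, its indicator function $\delta_{\mathcal{X}}$ is defined by $\delta_{\mathcal{X}}(\boldsymbol{x})=0$ if $\boldsymbol{x}\in\mathcal{X}$ and $\delta_{\mathcal{X}}(\boldsymbol{x})=+\infty$ otherwise.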

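In the following, a discrete probability distribution $\mathcal{P}$ with finite support points is specified by $\{(a_i,\boldsymbol{q}_i)\in\mathbb{R}_+\times\mathbb{R}^d : i=1,\dots,m\}$, where $\{\boldsymbol{q}_1,\dots,\boldsymbol{q}_m\}$ are the support points or vectors and $\{a_1,\dots,a_m\}$ are the associated probabilities or weights satisfying $\sum_{i=1}^m a_i=1$ and $a_i\ge 0$, $i=1,\dots,m$.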

2 Problem statement

In this section, we briefly recall the Wasserstein distance and describe the problem of computing a Wasserstein barycenter for a set of discrete probability distributions with finite support points. We refer interested readers to [34, Chapter 6] for more details on the Wasserstein distance and to [1, 2] for more details on the Wasserstein barycenter.

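Given two discrete distributions $\mathcal{P}^{(u)}=\{(a_i^{(u)},\boldsymbol{q}_i^{(u)}) : i=1,\dots,m_u\}$ and $\mathcal{P}^{(v)}=\{(a_i^{(v)},\boldsymbol{q}_i^{(v)}) : i=1,\dots,m_v\}$, the 2-Wasserstein distance between $\mathcal{P}^{(u)}$ and $\mathcal{P}^{(v)}$ is defined by $\mathcal{W}_2(\mathcal{P}^{(u)},\mathcal{P}^{(v)}):=\sqrt{v^*}$, where $v^*$ is the optimal objective value of the following linear program:

$$
v^*:=\min_{\pi_{ij}\ge 0}\Big\{\sum_{i=1}^{m_u}\sum_{j=1}^{m_v}\pi_{ij}\,\|\boldsymbol{q}_i^{(u)}-\boldsymbol{q}_j^{(v)}\|^2 \ :\ \sum_{j=1}^{m_v}\pi_{ij}=a_i^{(u)},\ i=1,\dots,m_u,\ \ \sum_{i=1}^{m_u}\pi_{ij}=a_j^{(v)},\ j=1,\dots,m_v\Big\}.
$$
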
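Then, given a set of discrete probability distributions $\{\mathcal{P}^{(t)}\}_{t=1}^N$ with $\mathcal{P}^{(t)}=\{(a_i^{(t)},\boldsymbol{q}_i^{(t)}) : i=1,\dots,m_t\}$, a Wasserstein barycenter $\mathcal{P}:=\{(w_i,\boldsymbol{x}_i) : i=1,\dots,m\}$ with $m$ support points ($m$ is pre-specified empirically) is an optimal solution of the following problem

$$
\min_{\mathcal{P}}\ \sum_{t=1}^N\gamma_t\,\mathcal{W}_2^2(\mathcal{P},\mathcal{P}^{(t)})
$$

for given weights $\gamma_1,\dots,\gamma_N$ satisfying $\sum_{t=1}^N\gamma_t=1$ and $\gamma_t>0$, $t=1,\dots,N$.
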
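This is a two-stage optimization problem that can be easily shown to be equivalent to

$$
\begin{aligned}
\min_{\boldsymbol{w},\,X,\,\{\Pi^{(t)}\}}\ & \sum_{t=1}^N \gamma_t\,\langle D(X,Q^{(t)}),\,\Pi^{(t)}\rangle \\
\text{s.t.}\ \ & \Pi^{(t)}\boldsymbol{e}_{m_t}=\boldsymbol{w},\ \ (\Pi^{(t)})^\top\boldsymbol{e}_m=\boldsymbol{a}^{(t)},\ \ \Pi^{(t)}\ge 0,\ \ t=1,\dots,N,\\
& \boldsymbol{e}_m^\top\boldsymbol{w}=1,\ \ \boldsymbol{w}\ge 0,
\end{aligned}
\tag{2.1}
$$

where

  • $\boldsymbol{e}_m$ (resp. $\boldsymbol{e}_{m_t}$) denotes the $m$ (resp. $m_t$) dimensional vector with all entries being 1;

  • $\boldsymbol{w}:=(w_1,\dots,w_m)^\top\in\mathbb{R}^m_+$, $X:=[\boldsymbol{x}_1,\dots,\boldsymbol{x}_m]\in\mathbb{R}^{d\times m}$;

  • $\boldsymbol{a}^{(t)}:=(a_1^{(t)},\dots,a_{m_t}^{(t)})^\top\in\mathbb{R}^{m_t}_+$, $Q^{(t)}:=[\boldsymbol{q}_1^{(t)},\dots,\boldsymbol{q}_{m_t}^{(t)}]\in\mathbb{R}^{d\times m_t}$ for $t=1,\dots,N$;

  • $\Pi^{(t)}=[\pi_{ij}^{(t)}]\in\mathbb{R}^{m\times m_t}$, $D(X,Q^{(t)}):=\big[\|\boldsymbol{x}_i-\boldsymbol{q}_j^{(t)}\|^2\big]\in\mathbb{R}^{m\times m_t}$ for $t=1,\dots,N$.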

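Note that (2.1) is a nonconvex problem, where one needs to find the optimal support $X$ and the optimal weight vector $\boldsymbol{w}$ of a barycenter simultaneously. However, in many real applications, the support $X$ of a barycenter can be specified empirically from the support points of $\{\mathcal{P}^{(t)}\}_{t=1}^N$. Indeed, in some cases, all distributions in $\{\mathcal{P}^{(t)}\}_{t=1}^N$ have the same set of support points and hence the barycenter should also take the same set of support points. Thus, one only needs to find the weight vector $\boldsymbol{w}$ of a barycenter. In view of this, from now on, we assume that the support $X$ is given. Consequently, problem (2.1) reduces to the following problem:

$$
\begin{aligned}
\min_{\boldsymbol{w},\,\{\Pi^{(t)}\}}\ & \sum_{t=1}^N \langle D^{(t)},\,\Pi^{(t)}\rangle \\
\text{s.t.}\ \ & \Pi^{(t)}\boldsymbol{e}_{m_t}=\boldsymbol{w},\ \ (\Pi^{(t)})^\top\boldsymbol{e}_m=\boldsymbol{a}^{(t)},\ \ \Pi^{(t)}\ge 0,\ \ t=1,\dots,N,\\
& \boldsymbol{e}_m^\top\boldsymbol{w}=1,\ \ \boldsymbol{w}\ge 0,
\end{aligned}
\tag{2.2}
$$

where $D^{(t)}$ denotes $\gamma_t D(X,Q^{(t)})$ for simplicity. This is also the main problem studied in [4, 6, 7, 11, 13, 33, 36, 37] for the computation of Wasserstein barycenters. Moreover, one can easily see that (2.2) is indeed a large-scale LP containing $m+m\sum_{t=1}^N m_t$ variables with nonnegative constraints and $Nm+\sum_{t=1}^N m_t+1$ equality constraints. For $N=100$, $m=1000$ and $m_t=1000$ for all $t=1,\dots,N$, the LP has about $10^8$ variables and $2\times 10^5$ equality constraints.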
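The size count above is simple arithmetic, and a short sketch can make it concrete. The snippet below is only an illustration under the assumption that the LP has one weight vector of length `m`, one `m`-by-`m_t` transport plan per input distribution, a row-sum and a column-sum constraint per plan, and one normalization constraint on the weights. The function name `barycenter_lp_size` and its arguments are hypothetical.

```python
def barycenter_lp_size(m, support_sizes):
    """Count variables and equality constraints of the fixed-support barycenter LP.

    m             : number of barycenter support points (assumed symbol)
    support_sizes : list [m_1, ..., m_N] of support sizes of the input distributions
    """
    N = len(support_sizes)
    total = sum(support_sizes)
    n_vars = m + m * total      # weight vector plus one m-by-m_t plan per distribution
    n_eq = N * m + total + 1    # row sums, column sums, and the simplex normalization
    return n_vars, n_eq


# Example: 100 distributions, 1000 support points each, 1000 barycenter points.
n_vars, n_eq = barycenter_lp_size(1000, [1000] * 100)
print(n_vars, n_eq)
```

Even for this moderate setting the constraint matrix is far too large to factorize naively, which is the motivation for first-order methods.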

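Remark 2.1 (Practical computational consideration when $\boldsymbol{a}^{(t)}$ is sparse).

Note that any feasible point $(\boldsymbol{w},\{\Pi^{(t)}\})$ of (2.2) must satisfy $(\Pi^{(t)})^\top\boldsymbol{e}_m=\boldsymbol{a}^{(t)}$ and $\Pi^{(t)}\ge 0$ for any $t=1,\dots,N$. This implies that if $a_j^{(t)}=0$ for some $1\le j\le m_t$ and $1\le t\le N$, then $\pi_{ij}^{(t)}=0$ for all $1\le i\le m$, i.e., all entries in the $j$-th column of $\Pi^{(t)}$ are zeros. Based on this fact, one can verify the following statements.

  • For any optimal solution $(\boldsymbol{w}^*,\{\Pi^{(t),*}\})$ of (2.2), the point $(\boldsymbol{w}^*,\{\Pi^{(t),*}_{\mathcal{J}_t}\})$ is also an optimal solution of the following problem

    $$
    \begin{aligned}
    \min_{\boldsymbol{w},\,\{\widehat{\Pi}^{(t)}\}}\ & \sum_{t=1}^N \langle D^{(t)}_{\mathcal{J}_t},\,\widehat{\Pi}^{(t)}\rangle \\
    \text{s.t.}\ \ & \widehat{\Pi}^{(t)}\boldsymbol{e}_{m'_t}=\boldsymbol{w},\ \ (\widehat{\Pi}^{(t)})^\top\boldsymbol{e}_m=\boldsymbol{a}^{(t)}_{\mathcal{J}_t},\ \ \widehat{\Pi}^{(t)}\ge 0,\ \ t=1,\dots,N,\\
    & \boldsymbol{e}_m^\top\boldsymbol{w}=1,\ \ \boldsymbol{w}\ge 0,
    \end{aligned}
    \tag{2.3}
    $$

    where $\mathcal{J}_t$ denotes the support set of $\boldsymbol{a}^{(t)}$, i.e., $\mathcal{J}_t:=\{j : a_j^{(t)}\neq 0\}$, $m'_t$ denotes the cardinality of $\mathcal{J}_t$, $\boldsymbol{a}^{(t)}_{\mathcal{J}_t}\in\mathbb{R}^{m'_t}$ denotes the subvector of $\boldsymbol{a}^{(t)}$ obtained by selecting the entries indexed by $\mathcal{J}_t$ and $D^{(t)}_{\mathcal{J}_t}\in\mathbb{R}^{m\times m'_t}$ denotes the submatrix of $D^{(t)}$ obtained by selecting the columns indexed by $\mathcal{J}_t$.

  • For any optimal solution $(\widehat{\boldsymbol{w}}^*,\{\widehat{\Pi}^{(t),*}\})$ of (2.3), the point $(\widehat{\boldsymbol{w}}^*,\{\Pi^{(t),*}\})$ obtained by setting $\Pi^{(t),*}_{\mathcal{J}_t}=\widehat{\Pi}^{(t),*}$ and $\Pi^{(t),*}_{\mathcal{J}_t^c}=0$ is also an optimal solution of (2.2), where $\mathcal{J}_t^c:=\{j : a_j^{(t)}=0\}$.

Therefore, one can obtain an optimal solution of (2.2) by computing an optimal solution of (2.3). Note that the problem size of (2.3) can be much smaller than that of (2.2) when each $\boldsymbol{a}^{(t)}$ is sparse, i.e., $m'_t\ll m_t$. Thus, solving (2.3) can reduce the computational cost and save memory in practice. Since (2.3) takes the same form as (2.2), we only consider (2.2) in the following.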
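The reduction in Remark 2.1 amounts to dropping the columns of each cost matrix where the corresponding weight is zero, and padding the reduced plan with zero columns afterwards. The following NumPy sketch illustrates that bookkeeping only; the helper names `reduce_columns` and `lift_plan` are hypothetical, and the reduced plan used in the example is an arbitrary feasible matrix, not the output of a solver.

```python
import numpy as np


def reduce_columns(D_list, a_list):
    """Keep only the columns of each cost matrix whose weight a_j is nonzero."""
    supports = [np.flatnonzero(a > 0) for a in a_list]
    D_red = [D[:, J] for D, J in zip(D_list, supports)]
    a_red = [a[J] for a, J in zip(a_list, supports)]
    return D_red, a_red, supports


def lift_plan(Pi_red, J, m_t):
    """Embed a reduced plan into an m-by-m_t plan with zero columns off the support."""
    Pi = np.zeros((Pi_red.shape[0], m_t))
    Pi[:, J] = Pi_red
    return Pi


# Illustrative data: 3 barycenter points, 4 support points, two of them with zero weight.
D = np.arange(12, dtype=float).reshape(3, 4)
a = np.array([0.5, 0.0, 0.5, 0.0])
D_red, a_red, supports = reduce_columns([D], [a])

Pi_red = np.full((3, 2), 1.0 / 6.0)   # some feasible reduced plan (column sums = 0.5)
Pi = lift_plan(Pi_red, supports[0], a.size)
```

The lifted plan has the same objective value and the same row sums as the reduced one, so feasibility and optimality carry over as stated in the remark.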

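For notational simplicity, let $\Delta_m:=\{\boldsymbol{w}\in\mathbb{R}^m : \boldsymbol{e}_m^\top\boldsymbol{w}=1,\ \boldsymbol{w}\ge 0\}$ and $\delta_+^t$ be the indicator function over $\{\Pi^{(t)}\in\mathbb{R}^{m\times m_t} : \Pi^{(t)}\ge 0\}$ for each $t=1,\dots,N$. Then, (2.2) can be equivalently written as

$$
\begin{aligned}
\min_{\boldsymbol{w},\,\{\Pi^{(t)}\}}\ & \delta_{\Delta_m}(\boldsymbol{w})+\sum_{t=1}^N\delta_+^t(\Pi^{(t)})+\sum_{t=1}^N\langle D^{(t)},\,\Pi^{(t)}\rangle \\
\text{s.t.}\ \ & \Pi^{(t)}\boldsymbol{e}_{m_t}=\boldsymbol{w},\ \ (\Pi^{(t)})^\top\boldsymbol{e}_m=\boldsymbol{a}^{(t)},\ \ t=1,\dots,N.
\end{aligned}
\tag{2.4}
$$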

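We next derive the dual problem of (2.4) (hence (2.2)). To this end, we write down the Lagrangian function associated with (2.4) as follows:

$$
\Upsilon\big(\boldsymbol{w},\{\Pi^{(t)}\};\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}\big):=\delta_{\Delta_m}(\boldsymbol{w})+\sum_{t=1}^N\delta_+^t(\Pi^{(t)})+\sum_{t=1}^N\langle D^{(t)},\Pi^{(t)}\rangle+\sum_{t=1}^N\langle \boldsymbol{y}^{(t)},\,\Pi^{(t)}\boldsymbol{e}_{m_t}-\boldsymbol{w}\rangle+\sum_{t=1}^N\langle \boldsymbol{z}^{(t)},\,(\Pi^{(t)})^\top\boldsymbol{e}_m-\boldsymbol{a}^{(t)}\rangle,
$$

where $\boldsymbol{y}^{(t)}\in\mathbb{R}^m$, $\boldsymbol{z}^{(t)}\in\mathbb{R}^{m_t}$, $t=1,\dots,N$ are multipliers. Then, the dual problem of (2.4) is given by

$$
\max_{\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}}\ \Big\{\min_{\boldsymbol{w},\{\Pi^{(t)}\}}\ \Upsilon\big(\boldsymbol{w},\{\Pi^{(t)}\};\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}\big)\Big\}. \tag{2.5}
$$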

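Observe that

$$
\min_{\boldsymbol{w},\{\Pi^{(t)}\}}\Upsilon=
\begin{cases}
-\delta^*_{\Delta_m}\big(\sum_{t=1}^N\boldsymbol{y}^{(t)}\big)-\sum_{t=1}^N\langle \boldsymbol{z}^{(t)},\boldsymbol{a}^{(t)}\rangle, & \text{if } D^{(t)}+\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top+\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top\ge 0,\ t=1,\dots,N,\\
-\infty, & \text{otherwise},
\end{cases}
$$

where $\delta^*_{\Delta_m}$ is the Fenchel conjugate of $\delta_{\Delta_m}$. Thus, (2.5) is equivalent to

$$
\min_{\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}}\ \delta^*_{\Delta_m}\Big(\sum_{t=1}^N\boldsymbol{y}^{(t)}\Big)+\sum_{t=1}^N\langle \boldsymbol{z}^{(t)},\boldsymbol{a}^{(t)}\rangle\quad \text{s.t.}\quad D^{(t)}+\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top+\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top\ge 0,\ \ t=1,\dots,N.
$$

By introducing auxiliary variables $\boldsymbol{u},V^{(1)},\dots,V^{(N)}$, we can further reformulate the above problem as

$$
\begin{aligned}
\min_{\boldsymbol{u},\{V^{(t)}\},\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}}\ & \delta^*_{\Delta_m}(\boldsymbol{u})+\sum_{t=1}^N\delta_+^t(V^{(t)})+\sum_{t=1}^N\langle \boldsymbol{z}^{(t)},\boldsymbol{a}^{(t)}\rangle \\
\text{s.t.}\ \ & \sum_{t=1}^N\boldsymbol{y}^{(t)}-\boldsymbol{u}=0,\\
& V^{(t)}-D^{(t)}-\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top=0,\ \ t=1,\dots,N.
\end{aligned}
\tag{2.6}
$$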

Note that (2.6) can be viewed as a linearly constrained convex problem with 3 blocks of variables grouped as $(\boldsymbol{u},\{V^{(t)}\})$, $\{\boldsymbol{y}^{(t)}\}$ and $\{\boldsymbol{z}^{(t)}\}$, whose objective is nonsmooth only with respect to $(\boldsymbol{u},\{V^{(t)}\})$ and linear with respect to the other two. Thus, this problem exactly falls into the class of convex problems for which the sGS-ADMM is applicable; see [10, 25]. Then, it is natural to adapt the sGS-ADMM for solving (2.6), which is presented in the next section.

Remark 2.2 (2-block ADMM for solving (2.2)).

It is worth noting that one can also apply the 2-block ADMM to solve the primal problem (2.2) by introducing some proper auxiliary variables. For example, one can consider the following equivalent reformulation of (2.2):

where . Then, the 2-block ADMM can be readily applied with being one block and being the other one. This 2-block ADMM avoids solving the quadratic programming subproblems and hence is more efficient than the one used in [36]. However, it needs to compute the projection onto the -dimensional simplex times when solving the -subproblem in each iteration. This is still time-consuming when or becomes large. Thus, this 2-block ADMM is also not efficient enough for solving large-scale problems. In addition, we have adapted the 2-block ADMM for solving other reformulations of (2.2), but they all perform worse than our sGS-ADMM presented later. Hence, we will no longer consider ADMM-type methods for solving the primal problem (2.2) or its equivalent variants in this paper.

3 sGS-ADMM for computing Wasserstein barycenters

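In this section, we present the sGS-ADMM for solving (2.6). First, we write down the augmented Lagrangian function associated with (2.6) as follows:

$$
\begin{aligned}
\mathcal{L}_\beta\big(\boldsymbol{u},\{V^{(t)}\},\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\};\boldsymbol{\lambda},\{\Lambda^{(t)}\}\big)
:=\ & \delta^*_{\Delta_m}(\boldsymbol{u})+\sum_{t=1}^N\delta_+^t(V^{(t)})+\sum_{t=1}^N\langle \boldsymbol{z}^{(t)},\boldsymbol{a}^{(t)}\rangle+\Big\langle \boldsymbol{\lambda},\ \sum_{t=1}^N\boldsymbol{y}^{(t)}-\boldsymbol{u}\Big\rangle \\
& +\sum_{t=1}^N\big\langle \Lambda^{(t)},\ V^{(t)}-D^{(t)}-\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top\big\rangle \\
& +\frac{\beta}{2}\Big\|\sum_{t=1}^N\boldsymbol{y}^{(t)}-\boldsymbol{u}\Big\|^2+\frac{\beta}{2}\sum_{t=1}^N\big\|V^{(t)}-D^{(t)}-\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top\big\|_F^2,
\end{aligned}
$$

where $\boldsymbol{\lambda}\in\mathbb{R}^m$, $\Lambda^{(t)}\in\mathbb{R}^{m\times m_t}$, $t=1,\dots,N$ are multipliers and $\beta>0$ is the penalty parameter. The sGS-ADMM for solving (2.6) is presented in Algorithm 1.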

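Input: the penalty parameter $\beta>0$, the dual step-size $\tau\in\big(0,\frac{1+\sqrt{5}}{2}\big)$ and the initialization $\boldsymbol{u}^0\in\mathbb{R}^m$, $\boldsymbol{y}^{(t),0}\in\mathbb{R}^m$, $\boldsymbol{z}^{(t),0}\in\mathbb{R}^{m_t}$, $V^{(t),0}\in\mathbb{R}^{m\times m_t}_+$, $\boldsymbol{\lambda}^0\in\mathbb{R}^m$, $\Lambda^{(t),0}\in\mathbb{R}^{m\times m_t}$, $t=1,\dots,N$. Set $k=0$.
while a termination criterion is not met, do

  • 1. Compute $\big(\boldsymbol{u}^{k+1},\{V^{(t),k+1}\}\big)=\arg\min_{\boldsymbol{u},\{V^{(t)}\}}\ \mathcal{L}_\beta\big(\boldsymbol{u},\{V^{(t)}\},\{\boldsymbol{y}^{(t),k}\},\{\boldsymbol{z}^{(t),k}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$.

  • 2a. Compute $\{\tilde{\boldsymbol{z}}^{(t),k+1}\}=\arg\min_{\{\boldsymbol{z}^{(t)}\}}\ \mathcal{L}_\beta\big(\boldsymbol{u}^{k+1},\{V^{(t),k+1}\},\{\boldsymbol{y}^{(t),k}\},\{\boldsymbol{z}^{(t)}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$.

  • 2b. Compute $\{\boldsymbol{y}^{(t),k+1}\}=\arg\min_{\{\boldsymbol{y}^{(t)}\}}\ \mathcal{L}_\beta\big(\boldsymbol{u}^{k+1},\{V^{(t),k+1}\},\{\boldsymbol{y}^{(t)}\},\{\tilde{\boldsymbol{z}}^{(t),k+1}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$.

  • 2c. Compute $\{\boldsymbol{z}^{(t),k+1}\}=\arg\min_{\{\boldsymbol{z}^{(t)}\}}\ \mathcal{L}_\beta\big(\boldsymbol{u}^{k+1},\{V^{(t),k+1}\},\{\boldsymbol{y}^{(t),k+1}\},\{\boldsymbol{z}^{(t)}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$.

  • 3. Compute

    $$
    \begin{aligned}
    \boldsymbol{\lambda}^{k+1}&=\boldsymbol{\lambda}^k+\tau\beta\Big(\sum_{t=1}^N\boldsymbol{y}^{(t),k+1}-\boldsymbol{u}^{k+1}\Big),\\
    \Lambda^{(t),k+1}&=\Lambda^{(t),k}+\tau\beta\big(V^{(t),k+1}-D^{(t)}-\boldsymbol{y}^{(t),k+1}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\boldsymbol{z}^{(t),k+1})^\top\big),\quad t=1,\dots,N.
    \end{aligned}
    $$

end while
Output: $\boldsymbol{u}^k$, $\{V^{(t),k}\}$, $\{\boldsymbol{y}^{(t),k}\}$, $\{\boldsymbol{z}^{(t),k}\}$, $\boldsymbol{\lambda}^k$, $\{\Lambda^{(t),k}\}$.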

Algorithm 1 sGS-ADMM for solving (2.6)

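Compared with the directly extended ADMM, our sGS-ADMM in Algorithm 1 just has one more update of $\{\boldsymbol{z}^{(t)}\}$ in Step 2a. This step is actually the key to guaranteeing the convergence of the algorithm. In the next section, we shall see that $\big(\{\boldsymbol{y}^{(t),k+1}\},\{\boldsymbol{z}^{(t),k+1}\}\big)$ computed from Step 2a-2c can be exactly obtained by minimizing $\mathcal{L}_\beta\big(\boldsymbol{u}^{k+1},\{V^{(t),k+1}\},\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$ plus a special proximal term with respect to $\big(\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}\big)$. Moreover, all subproblems in Algorithm 1 can be solved efficiently (in fact analytically) and the subproblems in each step can also be computed in parallel. This makes our method highly suitable for solving large-scale problems.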

The reader may have observed that instead of computing $\{\boldsymbol{y}^{(t)}\}$ and $\{\boldsymbol{z}^{(t)}\}$ sequentially as in Step 2a-2c, one can also compute $\big(\{\boldsymbol{y}^{(t)}\},\{\boldsymbol{z}^{(t)}\}\big)$ simultaneously in one step by solving a huge linear system of equations of dimension $mN+\sum_{t=1}^N m_t$. Unfortunately, for the latter approach, the computation of the solution would require the Cholesky factorization of a huge coefficient matrix, and this approach is not practically viable. In contrast, for our approach in Step 2a-2c, we shall see shortly that the solutions can be computed analytically without the need to perform Cholesky factorizations of large coefficient matrices. This also explains why we have designed the computations as in Step 2a-2c.
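To make the structure of one sweep concrete, here is a minimal NumPy sketch of the whole iteration. It is an illustration under stated assumptions, not the authors' implementation: the dual constraints are taken to be `sum_t y_t - u = 0` and `V_t - D_t - y_t e^T - e z_t^T = 0`, the closed-form updates are the stationarity conditions of the corresponding augmented Lagrangian worked out by hand, the simplex projection is a simple sort-based routine (a swapped-in substitute for the faster algorithm cited as [12]), and a fixed iteration count replaces a proper termination criterion. All function and variable names are hypothetical.

```python
import numpy as np


def proj_simplex(x):
    """Sort-based Euclidean projection onto {w >= 0, sum(w) = 1}."""
    s = np.sort(x)[::-1]
    css = np.cumsum(s) - 1.0
    k = np.arange(1, x.size + 1)
    rho = np.nonzero(s - css / k > 0)[0][-1]
    return np.maximum(x - css[rho] / (rho + 1.0), 0.0)


def sgs_admm(D_list, a_list, beta=1.0, tau=1.618, iters=5000):
    """Sketch of sGS-ADMM on the dual; returns multipliers (lam, Lam) ~ (weights, plans)."""
    N, m = len(D_list), D_list[0].shape[0]
    mt = [D.shape[1] for D in D_list]
    y = [np.zeros(m) for _ in range(N)]
    z = [np.zeros(k) for k in mt]
    lam = np.zeros(m)
    Lam = [np.zeros_like(D) for D in D_list]
    c = 1.0 + sum(1.0 / k for k in mt)
    for _ in range(iters):
        # Step 1: u via Moreau decomposition, V via projection onto the nonnegative orthant.
        x = sum(y) + lam / beta
        u = x - proj_simplex(beta * x) / beta
        V, B = [], []
        for t in range(N):
            W = D_list[t] + y[t][:, None] + z[t][None, :] - Lam[t] / beta
            V.append(np.maximum(W, 0.0))
            B.append(np.minimum(W, 0.0))          # equals W - V
        # Step 2a: extra z-update (the symmetric Gauss-Seidel sweep).
        z_tmp = [z[t] - (B[t].sum(axis=0) + a_list[t] / beta) / m for t in range(N)]
        # Step 2b: coupled y-update solved in closed form.
        h = beta * u - lam
        y_half = []
        for t in range(N):
            Bt = B[t] + (z_tmp[t] - z[t])[None, :]
            y_half.append(y[t] - (Bt.sum(axis=1) - h / beta) / mt[t])
        b = sum(y_half) / c                        # equals the sum of the new y's
        y = [y_half[t] - b / mt[t] for t in range(N)]
        # Step 2c: z-update with the new y.
        for t in range(N):
            R = D_list[t] + y[t][:, None] - Lam[t] / beta - V[t]
            z[t] = -(R.sum(axis=0) + a_list[t] / beta) / m
        # Step 3: multiplier updates.
        lam = lam + tau * beta * (sum(y) - u)
        for t in range(N):
            Lam[t] = Lam[t] + tau * beta * (V[t] - D_list[t] - y[t][:, None] - z[t][None, :])
    return lam, Lam


# Toy check: two identical distributions on {0, 1, 2}; the barycenter on the same
# support is that distribution itself.
pts = np.array([0.0, 1.0, 2.0])
D = 0.5 * (pts[:, None] - pts[None, :]) ** 2      # gamma_t = 0.5 folded into the cost
a = np.array([0.2, 0.3, 0.5])
w, plans = sgs_admm([D, D], [a, a])
```

At a fixed point of this loop the multipliers satisfy the optimality conditions of the primal LP, which is why `lam` and `Lam` are returned as the approximate barycenter weights and transport plans. No matrix factorization appears anywhere; every update is a row sum, a column sum, an elementwise max or min, or one simplex projection.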

Before ending this section, we present the computational details and the efficient implementations in each step of Algorithm 1.

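  • 1. Note that $\mathcal{L}_\beta\big(\boldsymbol{u},\{V^{(t)}\},\{\boldsymbol{y}^{(t),k}\},\{\boldsymbol{z}^{(t),k}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$ is actually separable with respect to $\boldsymbol{u},V^{(1)},\dots,V^{(N)}$ and hence one can compute $\boldsymbol{u}^{k+1},V^{(1),k+1},\dots,V^{(N),k+1}$ independently. Specifically, $\boldsymbol{u}^{k+1}$ is obtained by solving

    $$
    \min_{\boldsymbol{u}}\ \Big\{\delta^*_{\Delta_m}(\boldsymbol{u})-\langle \boldsymbol{\lambda}^k,\boldsymbol{u}\rangle+\frac{\beta}{2}\Big\|\sum_{t=1}^N\boldsymbol{y}^{(t),k}-\boldsymbol{u}\Big\|^2\Big\}.
    $$

    Thus, we have

    $$
    \boldsymbol{u}^{k+1}=\mathrm{Prox}_{\beta^{-1}\delta^*_{\Delta_m}}\Big(\beta^{-1}\boldsymbol{\lambda}^k+\sum_{t=1}^N\boldsymbol{y}^{(t),k}\Big)=\Big(\beta^{-1}\boldsymbol{\lambda}^k+\sum_{t=1}^N\boldsymbol{y}^{(t),k}\Big)-\beta^{-1}\,\mathrm{Prox}_{\beta\delta_{\Delta_m}}\Big(\boldsymbol{\lambda}^k+\beta\sum_{t=1}^N\boldsymbol{y}^{(t),k}\Big),
    $$

    where the last equality follows from the Moreau decomposition (see [3, Theorem 14.3(ii)]), i.e., $\boldsymbol{x}=\mathrm{Prox}_{\nu f}(\boldsymbol{x})+\nu\,\mathrm{Prox}_{f^*/\nu}(\boldsymbol{x}/\nu)$ for any $\nu>0$, and the proximal mapping of $\beta\delta_{\Delta_m}$ (that is, the projection onto $\Delta_m$) can be computed efficiently by the algorithm proposed in [12] with the complexity of $\mathcal{O}(m)$ that is typically observed in practice. Moreover, for each $t=1,\dots,N$, $V^{(t),k+1}$ is obtained by solving

    $$
    \min_{V^{(t)}}\ \Big\{\delta_+^t(V^{(t)})+\langle \Lambda^{(t),k},V^{(t)}\rangle+\frac{\beta}{2}\big\|V^{(t)}-D^{(t)}-\boldsymbol{y}^{(t),k}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\boldsymbol{z}^{(t),k})^\top\big\|_F^2\Big\}.
    $$

    Then, it is easy to see that

    $$
    V^{(t),k+1}=\max\big\{\widetilde{D}^{(t),k}-\beta^{-1}\Lambda^{(t),k},\ 0\big\},
    $$

    where $\widetilde{D}^{(t),k}:=D^{(t)}+\boldsymbol{y}^{(t),k}\boldsymbol{e}_{m_t}^\top+\boldsymbol{e}_m(\boldsymbol{z}^{(t),k})^\top$. Note that $\widetilde{D}^{(t),k}$ is already computed for updating $\Lambda^{(t),k}$ in the previous iteration and thus it can be reused in the current iteration. The computational complexity in this step is $\mathcal{O}\big(m\sum_{t=1}^N m_t\big)$. We should emphasize that because the matrices such as $\{V^{(t)}\}$, $\{\Lambda^{(t)}\}$ are very large, even performing simple operations such as adding two such matrices can be time consuming. Thus we have paid special attention to designing the computations in each step of the sGS-ADMM so that matrices computed in one step can be reused for the next step.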
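As a sanity check on Step 1, the sketch below implements the two closed forms and verifies the first one through its optimality certificate. It assumes the `u`-subproblem is the proximal mapping of the support function of the unit simplex evaluated at `y_sum + lam / beta`, and the `V`-subproblem is a projection onto the nonnegative orthant. The projection uses a plain sort-based routine in O(m log m), which is a stand-in for the algorithm of [12]; the function names are hypothetical.

```python
import numpy as np


def proj_simplex(x):
    """Sort-based Euclidean projection onto {w >= 0, sum(w) = 1}."""
    s = np.sort(x)[::-1]
    css = np.cumsum(s) - 1.0
    k = np.arange(1, x.size + 1)
    rho = np.nonzero(s - css / k > 0)[0][-1]
    theta = css[rho] / (rho + 1.0)
    return np.maximum(x - theta, 0.0)


def update_u(lam, y_sum, beta):
    """Prox of (1/beta) * support function of the simplex, via Moreau decomposition."""
    x = y_sum + lam / beta
    return x - proj_simplex(beta * x) / beta


def update_V(D_tilde, Lam, beta):
    """Projection of D_tilde - Lam / beta onto the nonnegative orthant."""
    return np.maximum(D_tilde - Lam / beta, 0.0)


rng = np.random.default_rng(0)
lam, y_sum, beta = rng.normal(size=5), rng.normal(size=5), 2.0
u = update_u(lam, y_sum, beta)

# Certificate: p = beta * (x - u) must lie in the simplex and attain max_i u_i,
# i.e. p is a subgradient of the support function at u.
p = beta * (y_sum + lam / beta - u)
```

The certificate `p` is exactly the simplex projection that was subtracted, so the check costs nothing extra in practice; it is shown here only to confirm that the Moreau identity was applied with the right scaling.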

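  • 2a. Similarly, $\mathcal{L}_\beta\big(\boldsymbol{u}^{k+1},\{V^{(t),k+1}\},\{\boldsymbol{y}^{(t),k}\},\{\boldsymbol{z}^{(t)}\};\boldsymbol{\lambda}^k,\{\Lambda^{(t),k}\}\big)$ is separable with respect to $\boldsymbol{z}^{(1)},\dots,\boldsymbol{z}^{(N)}$ and then one can also compute $\tilde{\boldsymbol{z}}^{(t),k+1}$, $t=1,\dots,N$, independently. For each $t=1,\dots,N$, $\tilde{\boldsymbol{z}}^{(t),k+1}$ is obtained by solving

    $$
    \min_{\boldsymbol{z}^{(t)}}\ \Big\{\langle \boldsymbol{z}^{(t)},\boldsymbol{a}^{(t)}\rangle-\langle \Lambda^{(t),k},\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top\rangle+\frac{\beta}{2}\big\|V^{(t),k+1}-D^{(t)}-\boldsymbol{y}^{(t),k}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\boldsymbol{z}^{(t)})^\top\big\|_F^2\Big\}.
    $$

    It is easy to prove that

    $$
    \tilde{\boldsymbol{z}}^{(t),k+1}=\boldsymbol{z}^{(t),k}-\frac{1}{m}\Big(\beta^{-1}\boldsymbol{a}^{(t)}+(B^{(t),k})^\top\boldsymbol{e}_m\Big),
    $$

    where $B^{(t),k}:=\widetilde{D}^{(t),k}-\beta^{-1}\Lambda^{(t),k}-V^{(t),k+1}=\min\big\{\widetilde{D}^{(t),k}-\beta^{-1}\Lambda^{(t),k},\ 0\big\}$. Note that $\widetilde{D}^{(t),k}-\beta^{-1}\Lambda^{(t),k}$ has already been computed in Step 1 and hence $B^{(t),k}$ can be computed by just a simple $\min(\cdot)$ operation. We note that $\tilde{\boldsymbol{z}}^{(t),k+1}$ is computed analytically for all $t=1,\dots,N$. The computational complexity in this step is $\mathcal{O}\big(m\sum_{t=1}^N m_t\big)$.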
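The closed form in Step 2a can be checked numerically by confirming that the gradient of the `z`-subproblem vanishes at the computed point. The sketch below does this on random data. It assumes the subproblem objective is `<z, a> - <Lam, e z^T> + (beta/2) * ||V - D - y e^T - e z^T||_F^2` and that `B` denotes `D + y e^T + e z_old^T - Lam / beta - V`; the names `update_z`, `B`, and `D_tilde` are illustrative.

```python
import numpy as np


def update_z(z_old, a, B, beta):
    """Closed-form minimizer of the z-subproblem; B.sum(axis=0) is B^T e."""
    m = B.shape[0]
    return z_old - (B.sum(axis=0) + a / beta) / m


rng = np.random.default_rng(1)
m, m_t, beta = 4, 3, 0.7
D = rng.random((m, m_t))
y = rng.normal(size=m)
z_old = rng.normal(size=m_t)
Lam = rng.normal(size=(m, m_t))
a = rng.random(m_t)
a /= a.sum()

D_tilde = D + y[:, None] + z_old[None, :]
V = np.maximum(D_tilde - Lam / beta, 0.0)          # Step 1 output
B = np.minimum(D_tilde - Lam / beta, 0.0)          # equals D_tilde - Lam / beta - V
z_new = update_z(z_old, a, B, beta)

# Gradient of the z-subproblem at z_new; it should vanish.
residual = V - D - y[:, None] - z_new[None, :]
grad = a - Lam.sum(axis=0) - beta * residual.sum(axis=0)
```

Only column sums of an already available matrix are needed, which is why this extra update adds little to the cost of an iteration.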

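  • 2b. In this step, one can see that $\{\boldsymbol{y}^{(t)}\}$ are coupled in $\mathcal{L}_\beta$ (due to the term $\frac{\beta}{2}\|\sum_{t=1}^N\boldsymbol{y}^{(t)}-\boldsymbol{u}^{k+1}\|^2$) and hence the problem of minimizing $\mathcal{L}_\beta$ with respect to $\{\boldsymbol{y}^{(t)}\}$ cannot be reduced to $N$ separable subproblems. However, one can still compute them efficiently based on the following observation. Note that $\{\boldsymbol{y}^{(t),k+1}\}$ is obtained by solving

    $$
    \min_{\{\boldsymbol{y}^{(t)}\}}\ \Phi_k\big(\boldsymbol{y}^{(1)},\dots,\boldsymbol{y}^{(N)}\big):=\Big\langle \boldsymbol{\lambda}^k,\sum_{t=1}^N\boldsymbol{y}^{(t)}\Big\rangle+\frac{\beta}{2}\Big\|\sum_{t=1}^N\boldsymbol{y}^{(t)}-\boldsymbol{u}^{k+1}\Big\|^2+\sum_{t=1}^N\Big(-\langle \Lambda^{(t),k},\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top\rangle+\frac{\beta}{2}\big\|V^{(t),k+1}-D^{(t)}-\boldsymbol{y}^{(t)}\boldsymbol{e}_{m_t}^\top-\boldsymbol{e}_m(\tilde{\boldsymbol{z}}^{(t),k+1})^\top\big\|_F^2\Big).
    $$

    The gradient of $\Phi_k$ with respect to $\boldsymbol{y}^{(t)}$ is

    $$
    \nabla_{\boldsymbol{y}^{(t)}}\Phi_k=\beta\sum_{\ell=1}^N\boldsymbol{y}^{(\ell)}-\boldsymbol{h}^k+\beta m_t\big(\boldsymbol{y}^{(t)}-\boldsymbol{y}^{(t),k}\big)+\beta\widetilde{B}^{(t),k}\boldsymbol{e}_{m_t},
    $$

    where $\widetilde{B}^{(t),k}:=B^{(t),k}+\boldsymbol{e}_m\big(\tilde{\boldsymbol{z}}^{(t),k+1}-\boldsymbol{z}^{(t),k}\big)^\top$ and $\boldsymbol{h}^k:=\beta\boldsymbol{u}^{k+1}-\boldsymbol{\lambda}^k$. It follows from the optimality conditions, namely, $\nabla_{\boldsymbol{y}^{(t)}}\Phi_k\big(\boldsymbol{y}^{(1),k+1},\dots,\boldsymbol{y}^{(N),k+1}\big)=0$ for all $t=1,\dots,N$, that

    $$
    \sum_{\ell=1}^N\boldsymbol{y}^{(\ell),k+1}+m_t\big(\boldsymbol{y}^{(t),k+1}-\boldsymbol{y}^{(t),k}\big)+\widetilde{B}^{(t),k}\boldsymbol{e}_{m_t}-\beta^{-1}\boldsymbol{h}^k=0,\quad t=1,\dots,N. \tag{3.1}
    $$

    By dividing $m_t$ in (3.1) for $t=1,\dots,N$, adding all resulting equations and doing some simple algebraic manipulations, one can obtain that

    $$
    \tilde{\boldsymbol{b}}^k:=\sum_{\ell=1}^N\boldsymbol{y}^{(\ell),k+1}=\Big(1+\sum_{t=1}^N m_t^{-1}\Big)^{-1}\sum_{t=1}^N\tilde{\boldsymbol{y}}^{(t),k+1},\qquad \tilde{\boldsymbol{y}}^{(t),k+1}:=\boldsymbol{y}^{(t),k}-\frac{1}{m_t}\big(\widetilde{B}^{(t),k}\boldsymbol{e}_{m_t}-\beta^{-1}\boldsymbol{h}^k\big).
    $$

    Then, using this and (3.1), we have

    $$
    \boldsymbol{y}^{(t),k+1}=\tilde{\boldsymbol{y}}^{(t),k+1}-\frac{1}{m_t}\tilde{\boldsymbol{b}}^k,\quad t=1,\dots,N.
    $$

    Observe that we can compute $\boldsymbol{y}^{(t),k+1}$ analytically for $t=1,\dots,N$. In the above computations, one can first compute $\tilde{\boldsymbol{y}}^{(t),k+1}$ in parallel for $t=1,\dots,N$ to obtain $\tilde{\boldsymbol{b}}^k$. Then, $\boldsymbol{y}^{(t),k+1}$ can be computed in parallel for $t=1,\dots,N$. Observe that by using the updating formula for $\tilde{\boldsymbol{z}}^{(t),k+1}$ in Step 2a, we have that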