Compressive Learning of Generative Networks
Generative networks implicitly approximate complex densities from their sampling with impressive accuracy. However, because of the enormous scale of modern datasets, this training process is often computationally expensive. We cast generative network training into the recent framework of compressive learning: we reduce the computational burden of large-scale datasets by first harshly compressing them in a single pass as a single sketch vector. We then propose a cost function, which approximates the Maximum Mean Discrepancy metric, but requires only this sketch, which makes it time- and memory-efficient to optimize.
In the last few years, data-driven methods have taken over the state of the art in a staggering number of research and engineering applications. This success stems from a combination of two factors: machine learning models that combine expressive power and good generalization properties (e.g., deep neural networks), and the unprecedented availability of training data in enormous quantities.
Among such models, generative networks (GNs) have received a significant amount of interest for their ability to embed data-driven priors in general applications, e.g., for solving inverse problems such as super-resolution, deconvolution, inpainting, or compressive sensing, to name a few [bora2017compressed, mardani2017deep, rick2017one, Lucas_2018]. As explained in Sec. 2, GNs are deep neural networks (DNNs) trained to generate samples that mimic those available in a given dataset. By minimizing some well-crafted cost function during training, these networks implicitly learn the probability distribution synthesizing this dataset; passing randomly generated low-dimensional inputs through the GN then generates new high-dimensional samples.
In generative adversarial networks (GANs), this cost is dictated by a discriminator network that classifies real (training) and fake (generated) examples, the generative and the discriminator networks being learned simultaneously in a two-player zero-sum game [goodfellow2014generative]. While GANs are the gold standard, achieving the state of the art for a wide variety of tasks, they are notoriously hard to train due to the need to carefully balance the training of the two networks.
MMD-GNs minimize the simpler Maximum Mean Discrepancy (MMD) cost function [li2015generative, dziugaite2015training], i.e., a “kernelized” distance measuring the similarity of generated and real samples. Although training MMD-GNs is conceptually simpler than training GANs (we can resort to simple gradient descent-based solvers, e.g., SGD), its computational cost scales poorly with the dataset size: each iteration necessitates numerous (typically of the order of thousands) accesses to the whole dataset. This severely limits the practical use of MMD-GNs [arjovsky2017wasserstein].
Indeed, modern machine learning models such as GNs are typically learned from numerous (e.g., several million) training examples. Aggregating, storing, and learning from such large-scale datasets is a serious challenge, as the required communication, memory, and processing resources inflate accordingly. In compressive learning (CL), larger datasets can be exploited without demanding more computational resources. The data is first harshly compressed to a single vector called the sketch, a process done in a single, easily parallelizable pass over the dataset [gribonval2017compressiveStatisticalLearning]. The actual learning is then performed from the sketch only, which acts as a light proxy for the whole dataset statistics. However, CL has so far been limited to “simple” models explicitly parametrized by a handful of parameters, such as k-means clustering, Gaussian mixture modeling or PCA [gribonval2017compressiveStatisticalLearning].
This work proposes and assesses the potential of sketching to “compressively learn” deep generative networks (MMD-GNs) with greatly reduced computational cost (see Fig. 1). By defining a cost function and practical learning scheme, our approach serves as a prototype for compressively learning general generative models from sketches. The effectiveness of this scheme is tested on toy examples.
2 Background, related work and notations
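To fix the ideas, given some space $\Sigma \subset \mathbb{R}^d$, we assimilate any dataset $X = \{x_i\}_{i=1}^N \subset \Sigma$ with $N$ samples to a discrete probability measure $\hat{\mathcal P}_X$, i.e., an empirical estimate for the probability distribution $\mathcal P^*$ generating $X$. Said differently, $x_i \sim_{\mathrm{i.i.d.}} \mathcal P^*$ and $\hat{\mathcal P}_X := \frac{1}{N} \sum_{i=1}^N \delta_{x_i}$, where $\delta_u$ is the Dirac measure at $u \in \Sigma$.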
2.1. Compressive Learning:
In CL, massive datasets are first efficiently (in one parallelizable pass) compressed into a single sketch vector of moderate size. The required parameters are then extracted from this sketch, using limited computational resources compared to usual algorithms that operate on the full dataset [gribonval2017compressiveStatisticalLearning].
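The sketch operator $\mathcal A$ realizes an embedding of any (infinite dimensional) probability measure $\mathcal P$ into the low-dimensional domain $\mathbb C^m$. This sketching amounts to taking the expectation of the random Fourier features (RFF) [Rahimi2008RFF] $\Phi(x) := \exp(\mathrm i\, \Omega^\top x) \in \mathbb C^m$ of $x \sim \mathcal P$, with $\Omega = (\omega_1, \ldots, \omega_m) \in \mathbb R^{d \times m}$. For large values of $N$, we expect that
$$\mathcal A(\mathcal P^*) := \mathbb E_{x \sim \mathcal P^*} \Phi(x) \;\approx\; \mathcal A(\hat{\mathcal P}_X) = \frac{1}{N} \sum_{i=1}^N \Phi(x_i) =: z_X, \qquad (1)$$
where $z_X$ is the sketch of the dataset $X$. This sketch, which has a constant size $m$ whatever the cardinality $N$ of $X$, thus embeds $\hat{\mathcal P}_X$ by empirically averaging (or pooling) all RFF vectors $\Phi(x_i)$. We still need to specify the RFF projection matrix $\Omega$; it is randomly generated by drawing $m$ “frequencies” $\omega_j \sim_{\mathrm{i.i.d.}} \Lambda$. In other words, $\mathcal A(\mathcal P)$ corresponds here to a random sampling (according to the law $\Lambda$) of the characteristic function of $\mathcal P$ (i.e., its Fourier transform). By Bochner’s theorem [Rudin1962bochnerBook], $\Lambda$ is related to some shift-invariant kernel $\kappa(x, x') = K(x - x')$ by the (inverse) Fourier transform: $K(u) = \int e^{\mathrm i \omega^\top u}\, \mathrm d\Lambda(\omega)$.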
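The sketching step described above can be illustrated with a minimal NumPy sketch. It assumes a Gaussian frequency law (frequencies drawn from a centered Gaussian of scale `1 / sigma`); the names `draw_frequencies`, `sketch`, `sigma`, and `chunk` are illustrative choices and are not taken from the authors' code.

```python
import numpy as np

def draw_frequencies(d, m, sigma=1.0, seed=0):
    """Draw m frequencies i.i.d. from N(0, I_d / sigma^2) (assumed frequency law)."""
    rng = np.random.default_rng(seed)
    return rng.normal(scale=1.0 / sigma, size=(d, m))

def sketch(X, Omega, chunk=10_000):
    """Average the random Fourier features exp(i * Omega^T x) over the rows of X."""
    total = np.zeros(Omega.shape[1], dtype=complex)
    for start in range(0, X.shape[0], chunk):  # single pass, one chunk at a time
        total += np.exp(1j * (X[start:start + chunk] @ Omega)).sum(axis=0)
    return total / X.shape[0]

rng = np.random.default_rng(1)
Omega = draw_frequencies(d=2, m=50)
z_small = sketch(rng.normal(size=(1_000, 2)), Omega)
z_large = sketch(rng.normal(size=(100_000, 2)), Omega)
# Both sketches have m entries, whatever the number of samples.
print(z_small.shape, z_large.shape)
```

Because the sum over samples is associative, the chunks could equally be processed on different machines and their partial sums merged afterwards, which is what makes the pass easy to parallelize.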
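CL aims at learning, from only the sketch $z_X$, an approximation $\mathcal P_\theta$ for the density $\mathcal P^*$, parametrized by $\theta$. For example, $\theta$ collects the position of the $K$ centroids for compressive $K$-means, and the weights, centers and covariances of different Gaussians for compressive Gaussian mixtures fitting. This is achieved by solving the following density fitting (“sketch matching”) problem:
$$\theta^* \in \arg\min_\theta\ \| z_X - \mathcal A(\mathcal P_\theta) \|_2^2. \qquad (2)$$
For large values of $m$, the cost in (2) estimates, up to a factor $m$, a metric $D_\kappa(\hat{\mathcal P}_X, \mathcal P_\theta)$ between $\hat{\mathcal P}_X$ and $\mathcal P_\theta$, called the Maximum Mean Discrepancy (MMD) [gretton2012kernelTwoSample], that is kernelized by $\kappa$, i.e., writing $\kappa(\mathcal P, \mathcal Q) := \mathbb E_{x \sim \mathcal P,\, y \sim \mathcal Q}\, \kappa(x, y)$, the MMD reads
$$D_\kappa^2(\mathcal P, \mathcal Q) := \kappa(\mathcal P, \mathcal P) + \kappa(\mathcal Q, \mathcal Q) - 2 \kappa(\mathcal P, \mathcal Q). \qquad (3)$$
Using Bochner’s theorem, we can indeed rewrite (3) as
$$D_\kappa^2(\mathcal P, \mathcal Q) = \int |\phi_{\mathcal P}(\omega) - \phi_{\mathcal Q}(\omega)|^2\, \mathrm d\Lambda(\omega) \;\simeq\; \frac{1}{m} \| \mathcal A(\mathcal P) - \mathcal A(\mathcal Q) \|_2^2, \qquad (4)$$
where $\phi_{\mathcal P}(\omega) := \mathbb E_{x \sim \mathcal P}\, e^{\mathrm i \omega^\top x}$ is the characteristic function of $\mathcal P$.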
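The step from the kernel form of the MMD to its Fourier form can be spelled out as follows. This is a worked sketch under assumed notation: $\phi_{\mathcal P}(\omega) := \mathbb E_{x \sim \mathcal P}\, e^{\mathrm i \omega^\top x}$ denotes the characteristic function of $\mathcal P$, $K$ is the profile of the shift-invariant kernel, $\kappa(x, x') = K(x - x')$, and the exchange of expectation and integral is assumed to be valid.

```latex
\begin{align*}
\kappa(\mathcal P, \mathcal Q)
  &= \mathbb E_{x \sim \mathcal P,\, y \sim \mathcal Q}\, K(x - y)
   = \mathbb E_{x \sim \mathcal P,\, y \sim \mathcal Q} \int e^{\mathrm i \omega^\top (x - y)}\, \mathrm d\Lambda(\omega)
   = \int \phi_{\mathcal P}(\omega)\, \overline{\phi_{\mathcal Q}(\omega)}\, \mathrm d\Lambda(\omega), \\
D_\kappa^2(\mathcal P, \mathcal Q)
  &= \kappa(\mathcal P, \mathcal P) + \kappa(\mathcal Q, \mathcal Q) - 2 \kappa(\mathcal P, \mathcal Q)
   = \int |\phi_{\mathcal P}(\omega) - \phi_{\mathcal Q}(\omega)|^2\, \mathrm d\Lambda(\omega) \\
  &\approx \frac{1}{m} \sum_{j=1}^m |\phi_{\mathcal P}(\omega_j) - \phi_{\mathcal Q}(\omega_j)|^2
   = \frac{1}{m} \| \mathcal A(\mathcal P) - \mathcal A(\mathcal Q) \|_2^2,
   \qquad \omega_j \sim_{\mathrm{i.i.d.}} \Lambda .
\end{align*}
```

The second line uses that $\kappa(\mathcal P, \mathcal Q)$ is real for a real symmetric kernel, so that $-2\kappa(\mathcal P, \mathcal Q) = -2\,\Re \int \phi_{\mathcal P} \overline{\phi_{\mathcal Q}}\, \mathrm d\Lambda$; the last line is a Monte-Carlo approximation of the integral over the $m$ drawn frequencies.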
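Provided $\Lambda$ is supported on $\mathbb R^d$, $D_\kappa(\mathcal P, \mathcal Q) = 0$ if and only if $\mathcal P = \mathcal Q$ [Sriperumbudur2010hilbertEmbedding]. Thus, minimizing (2) accurately estimates $\mathcal P^*$ from $z_X$ if $m$ is large compared to the complexity of the model; e.g., in compressive $K$-means, CL requires experimentally $m = \mathcal O(Kd)$ to learn the centroids of $K$ clusters in $\mathbb R^d$.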
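The link between the sketch distance and the MMD can be checked numerically. The snippet below is a minimal sketch assuming a Gaussian kernel of bandwidth `sigma`, whose frequency law is a centered Gaussian of scale `1 / sigma`; the toy data, the sizes, and all function names are illustrative assumptions.

```python
import numpy as np

def gaussian_kernel_mean(A, B, sigma):
    """Mean of exp(-||a - b||^2 / (2 sigma^2)) over all pairs (a, b)."""
    sq = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2 * A @ B.T
    return np.exp(-sq / (2 * sigma ** 2)).mean()

def mmd2(X, Y, sigma):
    """Biased (V-statistic) estimate of the squared MMD from pairwise kernel values."""
    return (gaussian_kernel_mean(X, X, sigma) + gaussian_kernel_mean(Y, Y, sigma)
            - 2 * gaussian_kernel_mean(X, Y, sigma))

def sketch(X, Omega):
    """Empirical average of the random Fourier features of the rows of X."""
    return np.exp(1j * (X @ Omega)).mean(axis=0)

rng = np.random.default_rng(0)
sigma, m = 1.0, 4000
X = rng.normal(size=(300, 2))
Y = rng.normal(size=(300, 2)) + np.array([1.5, 0.0])
# Frequencies drawn from the Fourier transform of the assumed Gaussian kernel
Omega = rng.normal(scale=1 / sigma, size=(2, m))

exact = mmd2(X, Y, sigma)
approx = np.sum(np.abs(sketch(X, Omega) - sketch(Y, Omega)) ** 2) / m
print(f"pairwise MMD^2: {exact:.4f}, sketched estimate: {approx:.4f}")
```

The two values should agree up to a Monte-Carlo error that shrinks as the number of frequencies `m` grows.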
The non-convex sketch matching problem (2) is generally solved with greedy heuristics (e.g., CL-OMPR [keriven2016GMMestimation]). As they require a closed-form expression of $\mathcal A(\mathcal P_\theta)$ and the Jacobian $\nabla_\theta \mathcal A(\mathcal P_\theta)$, CL has so far been limited to cases where $\mathcal A(\mathcal P_\theta)$ is explicitly available and easy to manipulate.
2.2. Generative networks:
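To generate realistic data samples, a GN $\mathcal G_\theta : \mathbb R^p \to \Sigma$ (i.e., a DNN) with weights $\theta$ is trained as follows. Given $N'$ inputs $z_i$ randomly drawn in a low-dimensional latent space $\mathbb R^p$ from a simple distribution $\mathcal P_z$, e.g., $z_i \sim \mathcal N(0, I_p)$, we compute the empirical distribution $\hat{\mathcal P}_{X_\theta}$ of the generated samples $X_\theta := \{\mathcal G_\theta(z_i)\}_{i=1}^{N'}$. By design, $\hat{\mathcal P}_{X_\theta}$ is related to sampling the pushforward distribution of $\mathcal P_z$ by $\mathcal G_\theta$. The parameter $\theta$ is then set such that $\hat{\mathcal P}_{X_\theta} \approx \hat{\mathcal P}_X$. While several divergences have been proposed to quantify this objective, we focus here on minimizing the MMD metric [li2015generative, dziugaite2015training]. Using (3) and discarding constant terms, we get the MMD-GNs fitting problem:
$$\theta^* \in \arg\min_\theta\ \frac{1}{N'^2} \sum_{i,j=1}^{N'} \kappa\big(\mathcal G_\theta(z_i), \mathcal G_\theta(z_j)\big) - \frac{2}{N N'} \sum_{i=1}^{N} \sum_{j=1}^{N'} \kappa\big(x_i, \mathcal G_\theta(z_j)\big). \qquad (5)$$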
Li et al. called this approach generative moment matching networks, as minimizing (3) amounts to matching all the (infinite) moments of $\hat{\mathcal P}_X$ and $\hat{\mathcal P}_{X_\theta}$ thanks to the space kernelization yielded by $\kappa$ [Hall2005generalizedMethodMoments] (see Fig. 1).
If $\kappa$ is differentiable, gradient descent-based methods can be used to solve (5), using back-propagation to compute the gradients of $\mathcal G_\theta$. However, for $N$ true samples and $N'$ generated samples (or a batch size), each evaluation of the cost in (5) (and its gradient) requires $\mathcal O(N N')$ computations. Training MMD-GNs, while conceptually simpler than training GANs, is much slower due to all the pairwise evaluations of the kernel required at each iteration, especially for modern large-size datasets.
3 Compressive Learning of Generative Networks
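In this work, given a dataset $X$, we propose to learn a generative network $\mathcal G_\theta$ using only the sketch $z_X$ defined in (1) (see Fig. 1). For this, given $N'$ samples $z_i \sim \mathcal P_z$, we solve a generative network sketch matching problem that selects $\theta^* \in \arg\min_\theta \mathcal L(\theta; z_X)$ with
$$\mathcal L(\theta; z_X) := \big\| z_X - \mathcal A(\hat{\mathcal P}_{X_\theta}) \big\|_2^2 = \Big\| z_X - \frac{1}{N'} \sum_{i=1}^{N'} \Phi\big(\mathcal G_\theta(z_i)\big) \Big\|_2^2. \qquad (6)$$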
From (4), we reach $\tfrac{1}{m} \mathcal L(\theta; z_X) \simeq D_\kappa^2(\hat{\mathcal P}_X, \hat{\mathcal P}_{X_\theta})$ for large values of $m$, as established from the link relating $\kappa$ and $\Lambda$. Compared to the exact MMD in (5), the cost $\mathcal L$ in (6) is, however, much easier to optimize. Once the dataset sketch $z_X$ has been pre-computed (in one single pass over $X$, possibly in parallel), we only need to compute $\mathcal A(\hat{\mathcal P}_{X_\theta})$ (i.e., by computing $N'$ contributions $\Phi(\mathcal G_\theta(z_i))$ by feed-forward, before averaging them) to compute the Euclidean distance between both quantities. In short, we access $X$ only once then discard it, and evaluating the cost has complexity $\mathcal O(N' m)$, i.e., much smaller than $\mathcal O(N N')$, the complexity of the exact MMD (5) (see Sec. 2.2).
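The evaluation of the sketch-matching cost from a pre-computed sketch can be sketched as follows. The one-hidden-layer `generator`, its sizes, the Leaky ReLU slope, and the Gaussian frequencies are hypothetical stand-ins chosen for brevity; they do not reproduce the network or settings used in the experiments.

```python
import numpy as np

def leaky_relu(t, slope=0.2):  # slope value is illustrative
    return np.where(t > 0, t, slope * t)

def generator(params, latent):
    """Toy one-hidden-layer generator (hypothetical stand-in for a deeper GN)."""
    W1, b1, W2, b2 = params
    return leaky_relu(latent @ W1 + b1) @ W2 + b2

def sketch(X, Omega):
    """Empirical average of the random Fourier features of the rows of X."""
    return np.exp(1j * (X @ Omega)).mean(axis=0)

def sketch_matching_cost(params, latent, Omega, z_X):
    """Squared distance between the dataset sketch and the sketch of generated samples.
    Needs N' x m feature evaluations and never touches the dataset itself."""
    return np.sum(np.abs(z_X - sketch(generator(params, latent), Omega)) ** 2)

rng = np.random.default_rng(0)
p, h, d, m = 2, 16, 2, 100
Omega = rng.normal(size=(d, m))
true_params = (0.5 * rng.normal(size=(p, h)), np.zeros(h),
               0.5 * rng.normal(size=(h, d)), np.zeros(d))

# Toy "dataset" drawn from the generator itself, sketched once and then discarded
X = generator(true_params, rng.normal(size=(20_000, p)))
z_X = sketch(X, Omega)
del X

latent = rng.normal(size=(2_000, p))
shifted = true_params[:3] + (true_params[3] + 2.0,)  # same network, shifted output bias
cost_true = sketch_matching_cost(true_params, latent, Omega, z_X)
cost_shifted = sketch_matching_cost(shifted, latent, Omega, z_X)
print(cost_true, cost_shifted)
```

In this toy setting the cost should be close to zero at the parameters that generated the data and clearly larger once the output bias is shifted, even though only `z_X` is kept in memory.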
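Equally importantly, the gradient $\nabla_\theta \mathcal L(\theta; z_X)$ is easily computed. With the residual $r := z_X - \mathcal A(\hat{\mathcal P}_{X_\theta})$ and $r^H$ its conjugate transpose,
$$\nabla_\theta \mathcal L(\theta; z_X) = -\frac{2}{N'} \sum_{i=1}^{N'} \Re\Big[ r^H\, \frac{\partial \Phi(u)}{\partial u}\Big|_{u = \mathcal G_\theta(z_i)}\, \frac{\partial \mathcal G_\theta(z_i)}{\partial \theta} \Big]. \qquad (7)$$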
Above, $\frac{\partial \Phi(u)}{\partial u} \in \mathbb C^{m \times d}$ is the Jacobian matrix listing the partial derivatives of the $m$ sketch entries with respect to the $d$ dimensions of $u$, which is evaluated at the generated samples $\mathcal G_\theta(z_i)$. The last term is computed by the back-propagation algorithm as it contains the derivative of the network output $\mathcal G_\theta(z_i)$ (for $z_i$ fixed) with respect to the parameters $\theta$. Algorithmically, the feature function $\Phi$ amounts to an extra layer on top of the GN, with fixed weights $\Omega$ and activation $t \mapsto \exp(\mathrm i t)$. We can then plug those expressions into any gradient-based optimization solver.
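The chain rule described above can be illustrated on a case where every Jacobian is explicit. The snippet assumes a toy affine generator `G(z) = z @ W + b` (an assumption made only to keep back-propagation to two lines) and compares the analytic gradient with a finite-difference estimate; all names and sizes are illustrative.

```python
import numpy as np

def cost_and_grad(W, b, latent, Omega, z_X):
    """Sketch-matching cost and its gradient for a toy affine generator
    G(z) = z @ W + b (assumed here for brevity, not the network of the paper)."""
    n_gen = latent.shape[0]
    U = latent @ W + b                  # generated samples, shape (N', d)
    Phi = np.exp(1j * (U @ Omega))      # extra "layer": fixed weights Omega, activation exp(i t)
    r = z_X - Phi.mean(axis=0)          # residual in the sketch domain, shape (m,)
    cost = np.sum(np.abs(r) ** 2)
    # -(2 / N') * Re[r^H dPhi/du], evaluated at every generated sample, shape (N', d)
    grad_U = -(2.0 / n_gen) * np.real((1j * Phi * np.conj(r)) @ Omega.T)
    # Back-propagation through the affine generator
    return cost, latent.T @ grad_U, grad_U.sum(axis=0)

rng = np.random.default_rng(0)
d, p, m = 2, 2, 30
Omega = rng.normal(size=(d, m))
z_X = np.exp(1j * (rng.normal(size=(5000, d)) @ Omega)).mean(axis=0)
latent = rng.normal(size=(200, p))
W, b = 0.5 * rng.normal(size=(p, d)), rng.normal(size=d)

cost, grad_W, grad_b = cost_and_grad(W, b, latent, Omega, z_X)

# Finite-difference check on one weight entry and one bias entry
eps = 1e-6
W_eps = W.copy(); W_eps[0, 1] += eps
b_eps = b.copy(); b_eps[0] += eps
fd_W = (cost_and_grad(W_eps, b, latent, Omega, z_X)[0] - cost) / eps
fd_b = (cost_and_grad(W, b_eps, latent, Omega, z_X)[0] - cost) / eps
print(abs(fd_W - grad_W[0, 1]), abs(fd_b - grad_b[0]))
```

For a deep generator, the two back-propagation lines would be replaced by the automatic differentiation of the chosen framework, with `grad_U` playing the role of the gradient flowing into the network output.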
We conclude this section with an interesting interpretation of (6). While CL requires closed-form expressions for $\mathcal A(\mathcal P_\theta)$ and $\nabla_\theta \mathcal A(\mathcal P_\theta)$, our GN formalism actually estimates those quantities by Monte-Carlo sampling, i.e., replacing $\mathcal P_\theta$ by $\hat{\mathcal P}_{X_\theta}$. This thus opens CL to non-parametric density fitting.
For this preliminary work, we visually illustrate the effectiveness of minimizing (6) by considering three 2-D synthetic datasets made of samples (see the top row of Fig. 2): (i) a 2-D spiral , with and , (ii) a Gaussian mixture model of Gaussians, and (iii) samples in a circle, i.e., and for and fixed. We learn a GN mapping dimensional random Gaussian vectors to , passing through seven fully connected hidden layers of units each, activated by a Leaky ReLU function with slope . For this simple illustration, we sketch all datasets to a sketch of size . We found experimentally from a few trials that setting $\Lambda$ to a folded Gaussian distribution (see [keriven2016GMMestimation]) of scale is appropriate to draw the frequencies $\omega_j$. From those sketches, we then trained our generators according to (6), using the Keras framework. We fixed the number of generated samples to , which we split into mini-batches of samples when computing the gradient.
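Toy datasets of the three kinds mentioned above could be generated along the following lines. The parametrizations and every numerical value below (number of turns, noise levels, number of components, radius) are hypothetical placeholders for illustration; they are not the settings used in these experiments.

```python
import numpy as np

def make_spiral(n, rng, turns=2.0, noise=0.02):
    """Noisy points along a spiral whose radius grows linearly with the angle (illustrative)."""
    t = rng.uniform(0.0, 2 * np.pi * turns, size=n)
    r = t / (2 * np.pi * turns)
    pts = np.stack([r * np.cos(t), r * np.sin(t)], axis=1)
    return pts + noise * rng.normal(size=(n, 2))

def make_gmm(n, rng, n_components=6, spread=1.0, std=0.1):
    """Equal-weight isotropic Gaussian mixture with random centers (illustrative)."""
    centers = spread * rng.normal(size=(n_components, 2))
    labels = rng.integers(n_components, size=n)
    return centers[labels] + std * rng.normal(size=(n, 2))

def make_circle(n, rng, radius=1.0, radial_std=0.0):
    """Points with uniform angle and Gaussian radius around `radius` (illustrative)."""
    phi = rng.uniform(0.0, 2 * np.pi, size=n)
    r = radius + radial_std * rng.normal(size=n)
    return np.stack([r * np.cos(phi), r * np.sin(phi)], axis=1)

rng = np.random.default_rng(0)
spiral = make_spiral(1000, rng)
gmm = make_gmm(1000, rng)
circle = make_circle(1000, rng)
print(spiral.shape, gmm.shape, circle.shape)
```

Each array has one 2-D sample per row, so it can be passed directly to a sketching routine that averages random Fourier features over rows.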
Fig. 2 compares the densities of generated samples and re-generated samples after training (from the known densities) through their 2-D histograms. Note that while the datasets are simplistic, we restricted the training time to a few minutes and, except for the frequency distribution, no hyper-parameter tuning was performed. Despite a few outliers and missing probability masses, the visual proximity of the histograms indicates that our method is able to learn complex 2-D distributions. Our code and further experiments are available at https://github.com/schellekensv/CL-GN.
We proposed and tested a method that incorporates compressive learning ideas into generative network training from the Maximum Mean Discrepancy metric. When dealing with large-scale datasets, our approach is potentially orders of magnitude faster than exact MMD-based learning.
However, to embrace higher-dimensional applications (e.g., image restoration or large-scale inverse problems), future work will need to (i) devise efficient techniques to adjust the kernel $\kappa$ (i.e., the frequency distribution $\Lambda$) to the dataset $X$, and (ii) determine theoretically the required sketch size $m$ as a function of the dataset distribution $\mathcal P^*$. Concerning the choice of the kernel, a promising direction consists in tuning its Fourier transform directly from a lightweight sketch [keriven2016GMMestimation]. As for the required sketch size, this problem certainly relates to measuring the “complexity” of the true generating density $\mathcal P^*$, and to the general open question of why over-parametrized deep neural networks generalize so well.