Abstract

We introduce Minimal Achievable Sufficient Statistic (MASS) Learning, a training method for machine learning models that attempts to produce minimal sufficient statistics with respect to a class of functions (e.g. deep networks) being optimized over. In deriving MASS Learning, we also introduce Conserved Differential Information (CDI), an information-theoretic quantity that — unlike standard mutual information — can be usefully applied to deterministically-dependent continuous random variables like the input and output of a deep network. In a series of experiments, we show that deep networks trained with MASS Learning achieve competitive performance on supervised learning, regularization, and uncertainty quantification benchmarks.


 

Minimal Achievable Sufficient Statistic Learning

 

Milan Cvitkovic, Günther Koliander


Correspondence to: Milan Cvitkovic <mcvitkov@caltech.edu>.
Proceedings of the International Conference on Machine Learning, Long Beach, California, PMLR 97, 2019. Copyright 2019 by the author(s).
1. Introduction

The representation learning approach to machine learning focuses on finding a representation $Z$ of an input random variable $X$ that is useful for predicting a target random variable $Y$ (Goodfellow et al., 2016).

What makes a representation “useful” is much debated, but a common assertion is that $Z$ should be a minimal sufficient statistic of $X$ for $Y$ (Adragni & Cook, 2009; Shamir et al., 2010; James et al., 2017; Achille & Soatto, 2018b). That is:

  1. $Z$ should be a statistic of $X$. This means $Z = f(X)$ for some function $f$.

  2. $Z$ should be sufficient for $Y$. This means $p(x \mid z, y) = p(x \mid z)$ for all $x$, $y$, $z$ with $p(y, z) > 0$.

  3. Given that $Z$ is a sufficient statistic, it should be minimal with respect to $X$. This means that for any measurable, non-invertible function $g$, $g(Z)$ is no longer sufficient for $Y$. (This is not the most common phrasing of statistical minimality, but we feel it is more understandable; for the equivalence of this phrasing and the standard definition, see the Supplementary Material.)

In other words: a minimal sufficient statistic $Z$ is a random variable that tells you everything about $Y$ you could ever care about, but if you do any irreversible processing to $Z$, you are guaranteed to lose some information about $Y$.
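For a concrete example, suppose $X = (X_1, X_2)$ consists of two independent, uniformly random bits and $Y = X_1$. Then $Z = X$ is sufficient but not minimal, since the non-invertible map $g(x_1, x_2) = x_1$ leaves $g(Z) = X_1$ sufficient for $Y$. By contrast, $Z = X_1$ is a minimal sufficient statistic: every non-invertible function $g$ of a single bit is constant, so

$$I(X_1; Y) = I(X; Y) = 1 \text{ bit}, \qquad I\big(g(X_1); Y\big) = 0.$$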

Minimal sufficient statistics have a long history in the field of statistics (Lehmann & Scheffe, 1950; Dynkin, 1951). But the minimality condition (3, above) is perhaps too strong to be useful in machine learning, since it is a statement about any function $g$, rather than about functions in a practical hypothesis class like the class of deep neural networks.

Instead, in this work we consider minimal achievable sufficient statistics: sufficient statistics that are minimal among some particular set of functions.

Definition 1 (Minimal Achievable Sufficient Statistic).

Let $Z = f(X)$ be a sufficient statistic of $X$ for $Y$. $Z$ is minimal achievable with respect to a set of functions $\mathcal{F}$ if $f \in \mathcal{F}$ and, for any Lipschitz continuous, non-invertible function $g$, $g(Z)$ is no longer sufficient for $Y$.

1.1. Our Contributions
  • We introduce Conserved Differential Information (CDI), an information-theoretic quantity that, unlike mutual information, is meaningful for deterministically-dependent continuous random variables, such as the input and output of a deep network.

  • We introduce Minimal Achievable Sufficient Statistic Learning (MASS Learning), a training objective based on CDI for finding minimal achievable sufficient statistics.

  • We provide empirical evidence that models trained by MASS Learning achieve competitive performance on supervised learning, regularization, and uncertainty quantification benchmarks.

2. Conserved Differential Information

Before we present MASS Learning, we need to introduce Conserved Differential Information (CDI), on which MASS Learning is based.

CDI is an information-theoretic quantity that addresses an oft-cited issue in machine learning (Bell & Sejnowski, 1995; Amjad & Geiger, 2018; Saxe et al., 2018; Nash et al., 2018; Goldfeld et al., 2018), which is that for a continuous random variable $X$ and a continuous, non-constant function $f$, the mutual information $I(X; f(X))$ is infinite. (See the Supplementary Material for details.) This makes $I(X; f(X))$ unsuitable for use in a learning objective when $f$ is, for example, a standard deep network.

The infinitude of $I(X; f(X))$ has been circumvented in prior works by two strategies. One is to discretize $X$ and $f(X)$ (Tishby & Zaslavsky, 2015; Shwartz-Ziv & Tishby, 2017), though this is controversial (Saxe et al., 2018). Another is to use a random variable $Z$ with conditional distribution $p(z \mid x)$ as the representation of $X$, rather than using $f(X)$ itself as the representation (Alemi et al., 2017; Kolchinsky et al., 2017; Achille & Soatto, 2018b). In this latter approach, $p(z \mid x)$ is usually implemented by adding noise to a deep network that takes $X$ as input.

These are both reasonable strategies for avoiding the infinitude of $I(X; f(X))$. But another approach would be to derive a new information-theoretic quantity that is better suited to this situation. To that end we present Conserved Differential Information:

Definition 2.

For a continuous random variable $X$ taking values in $\mathbb{R}^d$ and a Lipschitz continuous function $f : \mathbb{R}^d \to \mathbb{R}^r$, the Conserved Differential Information (CDI) is

$$C\big(X \to f(X)\big) := h\big(f(X)\big) - \mathbb{E}_X\big[\log \mathcal{J}_f(X)\big] \qquad (1)$$

where $h$ denotes the differential entropy, $h(Z) = -\int p_Z(z) \log p_Z(z)\, dz$, and $\mathcal{J}_f$ is the Jacobian determinant of $f$,

$$\mathcal{J}_f(x) = \sqrt{\det\!\big(D_f(x)\, D_f(x)^{\top}\big)},$$

with $D_f(x)$ the Jacobian matrix of $f$ at $x$.
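To make the Jacobian-determinant term in (1) concrete, the following PyTorch sketch (ours, not the authors' released code; the toy network and shapes are arbitrary) computes $\log \mathcal{J}_f(x)$ at a single input by forming the full Jacobian matrix with automatic differentiation. The differential entropy term $h(f(X))$ has no closed form in general; the variational treatment we use for it appears later in the paper.

```python
import torch

def log_jacobian_determinant(f, x):
    """Compute log J_f(x) = 0.5 * log det(D_f(x) D_f(x)^T) at a single input x.

    f: maps a tensor of shape (d,) to a tensor of shape (r,), with r <= d.
    x: a tensor of shape (d,).
    """
    D = torch.autograd.functional.jacobian(f, x)   # Jacobian matrix D_f(x), shape (r, d)
    gram = D @ D.T                                  # D_f(x) D_f(x)^T, shape (r, r)
    return 0.5 * torch.linalg.slogdet(gram).logabsdet

# Toy example: a non-invertible map from R^4 to R^2.
net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ELU(), torch.nn.Linear(8, 2))
f = lambda u: net(u.unsqueeze(0)).squeeze(0)        # wrap to handle a single (unbatched) input

x = torch.randn(4)
print(log_jacobian_determinant(f, x))               # a scalar: log J_f(x)
```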

Readers familiar with normalizing flows (Rezende & Mohamed, 2015) or Real NVP (Dinh et al., 2017) will note that the Jacobian determinant used in those methods is a special case of the Jacobian determinant in the definition of CDI. This is because normalizing flows and Real NVP are based on the change of variables formula for invertible mappings, while CDI is based in part on the more general change of variables formula for non-invertible mappings. More details on this connection are given in the Supplementary Material. The mathematical motivation for CDI, based on the recent work of Koliander et al. (2016), is provided in the Supplementary Material. Figure 1 gives a visual example of what CDI measures about a function.

Figure 1: CDI of two functions $f_1$ and $f_2$ of the random variable $X$. Even though the random variables $f_1(X)$ and $f_2(X)$ have the same distribution, $C(X \to f_1(X))$ is different from $C(X \to f_2(X))$. This is because $f_1$ is an invertible function, while $f_2$ is not. CDI quantifies, roughly speaking, “how non-invertible” $f_2$ is.

The conserved differential information $C(X \to f(X))$ between continuous, deterministically-dependent random variables behaves a lot like mutual information does on discrete random variables. For example, when $f$ is invertible, $C(X \to f(X)) = h(X)$, just as $I(X; f(X)) = H(X)$ for invertible $f$ and discrete $X$. Most importantly for our purposes, though, CDI obeys the following data processing inequality:

Theorem 1 (CDI Data Processing Inequality).

For Lipschitz continuous functions $f$ and $g$ such that $f$ and $g \circ f$ have the same output space,

$$C\big(X \to g(f(X))\big) \;\le\; C\big(X \to f(X)\big),$$

with equality if and only if $g$ is invertible almost everywhere.

The proof is in the Supplementary Material.
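For the invertible case mentioned above, the identity $C(X \to f(X)) = h(X)$ follows directly from the change of variables formula for differential entropy (taking $f : \mathbb{R}^d \to \mathbb{R}^d$ invertible and differentiable):

$$h\big(f(X)\big) = h(X) + \mathbb{E}_X\big[\log\lvert\det D_f(X)\rvert\big], \qquad \mathcal{J}_f(X) = \sqrt{\det\!\big(D_f(X)\, D_f(X)^{\top}\big)} = \lvert\det D_f(X)\rvert,$$

so that

$$C\big(X \to f(X)\big) = h\big(f(X)\big) - \mathbb{E}_X\big[\log \mathcal{J}_f(X)\big] = h(X).$$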

3. MASS Learning

With CDI and its data processing inequality in hand, we can give the following optimization-based characterization of minimal achievable sufficient statistics:

Theorem 2.

Let $X$ be a continuous random variable, $Y$ be a discrete random variable, and $\mathcal{F}$ be any set of Lipschitz continuous functions with a common output space (e.g., different parameter settings of a deep network). If

$$f \in \operatorname*{arg\,min}_{\substack{f' \in \mathcal{F} \\ I(f'(X);\, Y) \,=\, I(X;\, Y)}} C\big(X \to f'(X)\big),$$

then $f(X)$ is a minimal achievable sufficient statistic of $X$ for $Y$ with respect to $\mathcal{F}$.

Proof.

First note the following lemma (Cover & Thomas, 2006).

Lemma 1.

$f(X)$ is a sufficient statistic of $X$ for a discrete random variable $Y$ if and only if $I(f(X); Y) = I(X; Y)$.

Lemma 1 guarantees that any $f$ satisfying the conditions in Theorem 2 gives a sufficient statistic $f(X)$. Suppose such an $f(X)$ were not minimal achievable. Then by Definition 1 there would exist a non-invertible, Lipschitz continuous $g$ such that $g(f(X))$ was sufficient. But by Theorem 1, it would then also be the case that $C(X \to g(f(X))) < C(X \to f(X))$, which would contradict $f$ minimizing $C(X \to f(X))$. ∎

We can turn Theorem 2 into a learning objective over functions $f \in \mathcal{F}$ by relaxing the strict constraint $I(f(X); Y) = I(X; Y)$ into a Lagrangian formulation with Lagrange multiplier $\beta > 0$:

$$\min_{f \in \mathcal{F}}\; -I\big(f(X); Y\big) + \beta\, C\big(X \to f(X)\big).$$

The larger the value of $\beta$, the more our objective will encourage minimality over sufficiency. We can then simplify this formulation using the identity $I(f(X); Y) = H(Y) - H(Y \mid f(X))$, whose first term does not depend on $f$, which gives us the following optimization objective:

$$L_{\text{MASS}}(f) := H\big(Y \mid f(X)\big) + \beta\, h\big(f(X)\big) - \beta\, \mathbb{E}_X\big[\log \mathcal{J}_f(X)\big]. \qquad (2)$$

We refer to minimizing this objective as MASS Learning.

3.1. MASS Learning in Practice

In practice, we are interested in using MASS Learning to train a deep network $f_{\theta}$ with parameters $\theta$ using a finite dataset $\{(x_i, y_i)\}_{i=1}^{N}$ of $N$ datapoints sampled from the joint distribution of $X$ and $Y$. To do this, we introduce a parameterized variational approximation $q_{\phi}\big(f_{\theta}(x) \mid y\big)$. Using $q_{\phi}$, we minimize the following empirical upper bound to $L_{\text{MASS}}$:

$$\hat{L}_{\text{MASS}}(f_{\theta}, q_{\phi}) := \frac{1}{N}\sum_{i=1}^{N}\Big[-\log q_{\phi}\big(y_i \mid f_{\theta}(x_i)\big) \;-\; \beta \log q_{\phi}\big(f_{\theta}(x_i)\big) \;-\; \beta \log \mathcal{J}_{f_{\theta}}(x_i)\Big],$$

where the quantity $q_{\phi}(f_{\theta}(x_i))$ is computed as $\sum_{y} q_{\phi}(f_{\theta}(x_i) \mid y)\, p(y)$ and the quantity $q_{\phi}(y_i \mid f_{\theta}(x_i))$ is computed with Bayes' rule as $q_{\phi}(f_{\theta}(x_i) \mid y_i)\, p(y_i) \,/\, q_{\phi}(f_{\theta}(x_i))$. When $Y$ is discrete and takes on finitely many values, as in classification problems, and when we choose a variational distribution $q_{\phi}$ that is differentiable with respect to $\phi$ (e.g., a multivariate Gaussian), then we can minimize $\hat{L}_{\text{MASS}}$ using stochastic gradient descent (SGD).

To perform classification using our trained network, we use the learned variational distribution and Bayes' rule:

$$q_{\phi}\big(y \mid f_{\theta}(x)\big) = \frac{q_{\phi}\big(f_{\theta}(x) \mid y\big)\, p(y)}{\sum_{y'} q_{\phi}\big(f_{\theta}(x) \mid y'\big)\, p(y')}.$$

Computing the $\log \mathcal{J}_{f_{\theta}}(x_i)$ term in $\hat{L}_{\text{MASS}}$ for every sample in an SGD minibatch is too expensive to be practical. For a network whose output $f_{\theta}(x)$ has dimension $r$, doing so would require on the order of $r$ times more operations than in standard training of deep networks by backpropagation, since computing the term involves computing the full Jacobian matrix of the network, which, in our implementation, involves performing $r$ backpropagations. Thus, to make training tractable, we use a subsampling strategy: we estimate the $\log \mathcal{J}_{f_{\theta}}$ term using only a fraction of the datapoints in a minibatch. In practice, we have found that this subsampling strategy does not noticeably alter the numerical value of the term during training.
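The following PyTorch sketch (a minimal illustration of ours, not the authors' released implementation; the class-conditional diagonal-Gaussian $q_{\phi}$, the uniform class prior, and the subsampling fraction are our own choices) shows how the three terms of $\hat{L}_{\text{MASS}}$ can be computed for a minibatch, with the Jacobian term estimated on a subsample:

```python
import math
import torch

class MASSLoss(torch.nn.Module):
    """Empirical MASS objective with a class-conditional diagonal-Gaussian q(z | y)."""

    def __init__(self, num_classes, repr_dim, beta):
        super().__init__()
        self.beta = beta
        # Variational parameters: one diagonal Gaussian over representations per class.
        self.means = torch.nn.Parameter(torch.randn(num_classes, repr_dim))
        self.log_stds = torch.nn.Parameter(torch.zeros(num_classes, repr_dim))
        # Uniform class prior p(y); it could instead be estimated from the training set.
        self.register_buffer("log_prior", torch.full((num_classes,), -math.log(num_classes)))

    def log_q_z_given_y(self, z):
        # log q(z | y) for every class y: shape (batch, num_classes).
        z = z.unsqueeze(1)                                       # (batch, 1, repr_dim)
        mean, log_std = self.means.unsqueeze(0), self.log_stds.unsqueeze(0)
        log_prob = -0.5 * (((z - mean) / log_std.exp()) ** 2) - log_std - 0.5 * math.log(2 * math.pi)
        return log_prob.sum(dim=-1)

    def forward(self, f, x, y, jacobian_subsample=4):
        z = f(x)                                                 # representations, (batch, repr_dim)
        log_joint = self.log_q_z_given_y(z) + self.log_prior    # log q(z | y) + log p(y)
        log_qz = torch.logsumexp(log_joint, dim=1)              # log q(z)
        log_qy_given_z = log_joint.gather(1, y.unsqueeze(1)).squeeze(1) - log_qz   # Bayes rule

        # Jacobian-determinant term, estimated on a small subsample of the minibatch.
        log_jac = []
        for xi in x[:jacobian_subsample]:
            D = torch.autograd.functional.jacobian(
                lambda u: f(u.unsqueeze(0)).squeeze(0), xi, create_graph=True)
            D = D.reshape(z.shape[1], -1)                       # (repr_dim, input_dim)
            log_jac.append(0.5 * torch.linalg.slogdet(D @ D.T).logabsdet)
        log_jac = torch.stack(log_jac).mean()

        return (-log_qy_given_z - self.beta * log_qz).mean() - self.beta * log_jac
```

Classification with the trained network then uses the same variational distribution, predicting $\arg\max_y \big[\log q_{\phi}(f_{\theta}(x) \mid y) + \log p(y)\big]$.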

Subsampling for the $\log \mathcal{J}_{f_{\theta}}$ term results in a significant training speedup, but it must nevertheless be emphasized that, even with subsampling, our implementation of MASS Learning is roughly eight times as slow as standard deep network training. (Unless $\beta = 0$, in which case the speed is the same.) This is by far the most significant drawback of (our implementation of) MASS Learning. There are many easier-to-compute upper bounds on or estimates of $\log \mathcal{J}_{f_{\theta}}$ that one could use to make MASS Learning faster, and one could also potentially find non-invertible network architectures that admit more efficiently computable Jacobians, but we do not explore these options in this work.

4. Related Work

4.1. Relation to the Information Bottleneck

The well-studied Information Bottleneck learning method (Tishby et al., 2000; Tishby & Zaslavsky, 2015; Strouse & Schwab, 2015; Alemi et al., 2017; Saxe et al., 2018; Amjad & Geiger, 2018; Goldfeld et al., 2018; Kolchinsky et al., 2019; Achille & Soatto, 2018b; a) is based on minimizing the Information Bottleneck Lagrangian

$$L_{IB} := I(X; Z) - \beta\, I(Z; Y)$$

for $\beta > 0$, where $Z$ is the representation whose conditional distribution $p(z \mid x)$ one is trying to learn.

The $L_{IB}$ objective can be motivated based on pure information-theoretic elegance. But some works, like Shamir et al. (2010), also point out the connection between the $L_{IB}$ objective and minimal sufficient statistics, which is based on the following theorem:

Theorem 3.

Let $Y$ be a discrete random variable drawn according to a distribution determined by the discrete random variable $X$. Let $\mathcal{S}$ be the set of deterministic functions of $X$ to any target space. Then $f(X)$ is a minimal sufficient statistic of $X$ for $Y$ if and only if

$$f \in \operatorname*{arg\,min}_{\substack{f' \in \mathcal{S} \\ I(f'(X);\, Y) \,=\, I(X;\, Y)}} I\big(X; f'(X)\big).$$

The $L_{IB}$ objective can then be thought of as a Lagrangian relaxation of the optimization problem in this theorem.

Theorem 3 only holds for discrete random variables. For continuous $X$ it holds only in the reverse direction, so minimizing $L_{IB}$ for continuous $X$ has no formal connection to finding minimal sufficient statistics, not to mention minimal achievable sufficient statistics. See the Supplementary Material for details.

Nevertheless, the optimization problems in Theorem 2 and Theorem 3 are extremely similar, relying as they both do on Lemma 1 for their proofs. And the idea of relaxing the optimization problem in Theorem 2 into a Lagrangian formulation to get $L_{\text{MASS}}$ is directly inspired by the Information Bottleneck. So while MASS Learning and Information Bottleneck learning entail different network architectures and loss functions, there is an Information Bottleneck flavor to MASS Learning.

4.2. Relation to Jacobian Regularization

The presence of the $\log \mathcal{J}_{f_{\theta}}$ term in $\hat{L}_{\text{MASS}}$ is reminiscent of the contractive autoencoder (Rifai et al., 2011) and the Jacobian regularization literature (Sokolic et al., 2017; Ross & Doshi-Velez, 2018; Varga et al., 2017; Novak et al., 2018; Jakubovitz & Giryes, 2018). Both of these literatures suggest that minimizing $\lVert D_f(x) \rVert_F^2$, where $D_f(x)$ is the Jacobian matrix of the network and $\lVert \cdot \rVert_F$ is the Frobenius norm, seems to improve generalization and adversarial robustness.

This may seem paradoxical at first, since by applying the AM-GM inequality to the eigenvalues of $D_f(x) D_f(x)^{\top}$ we have

$$\log \mathcal{J}_f(x) = \tfrac{1}{2} \log \det\!\big(D_f(x) D_f(x)^{\top}\big) \;\le\; \tfrac{r}{2} \log\!\Big(\tfrac{1}{r}\,\big\lVert D_f(x)\big\rVert_F^2\Big),$$

and $\log \mathcal{J}_f(x)$ is being maximized by $\hat{L}_{\text{MASS}}$ (it enters the objective with a negative sign). So $\hat{L}_{\text{MASS}}$ might seem to be optimizing for worse generalization according to the Jacobian regularization literature. However, the differential entropy term $h(f(X))$ in $L_{\text{MASS}}$, which is being minimized, strongly encourages keeping $\lVert D_f(x)\rVert_F^2$ small. So overall $\hat{L}_{\text{MASS}}$ seems to be seeking the right balance of sensitivity (dependent on the value of $\beta$) of the network to its inputs, which is precisely in alignment with what the Jacobian regularization literature suggests.

5. Experiments

In this section we compare MASS Learning to other approaches for training deep networks. Code to reproduce all experiments is available online (https://github.com/mwcvitkovic/MASS-Learning). Full details on all experiments are in the Supplementary Material.

We use the abbreviation “SoftmaxCE” to refer to the standard approach of training deep networks for classification problems by minimizing the softmax cross-entropy loss

$$L_{\text{SoftmaxCE}}(f_{\theta}) := -\frac{1}{N}\sum_{i=1}^{N} \log\big[\sigma\big(f_{\theta}(x_i)\big)\big]_{y_i},$$

where $[\sigma(f_{\theta}(x_i))]_{y_i}$ is the $y_i$th element of the softmax function $\sigma$ applied to the outputs $f_{\theta}(x_i)$ of the network’s last linear layer. As usual, $[\sigma(f_{\theta}(x_i))]_{y_i}$ is taken to be the network’s estimate of $p(y_i \mid x_i)$.

We also compare against the Variational Information Bottleneck method (Alemi et al., 2017) for representation learning, which we abbreviate as “VIB”.

We use two networks in our experiments. “SmallMLP” is a feedforward network with two fully-connected layers of 400 and 200 hidden units, respectively, both with elu nonlinearities (Clevert et al., 2015). “ResNet20” is the 20-layer residual network of He et al. (2016).

We performed all experiments on the CIFAR-10 dataset (Krizhevsky, 2009) and implemented all experiments using PyTorch (Paszke et al., 2017).

5.1. Classification Accuracy and Regularization

We first confirm that networks trained by MASS Learning can make accurate predictions in supervised learning tasks. We also compare the classification accuracy of networks trained on varying amounts of data to see whether MASS Learning regularizes networks and improves their generalization performance.

Classification accuracies for the SmallMLP network are shown in Table 1, and for the ResNet20 network in Table 2. For the SmallMLP network, MASS Learning performs comparably to SoftmaxCE and VIB, but does not appear to offer any performance benefits. For the larger ResNet20 network, the results show that while MASS Learning outperforms SoftmaxCE and VIB training, these improvements do not seem to be due to the MASS loss itself, since the same performance improvements are obtained even when the $\beta\, h(f(X))$ and $\beta\, \mathbb{E}_X[\log \mathcal{J}_f(X)]$ terms in the MASS loss are set to zero (i.e., the case when $\beta = 0$).

This suggests that it is the use of the variational distribution $q_{\phi}$ to compute the output of the network, rather than the MASS Learning approach itself, that is providing the benefit. This is an interesting finding and worthy of further study, but it does not suggest an advantage to using the full MASS Learning method if one is concerned with accuracy or regularization.


Method Training Set Size
2500 10,000 40,000

SoftmaxCE
SoftmaxCE, WD
SoftmaxCE, D
VIB, =
VIB, =
VIB, =
VIB, =, D
VIB, =, D
VIB, =, D
MASS, =
MASS, =
MASS, =
MASS, =
MASS, =, D
MASS, =, D
MASS, =, D
MASS, =, D

Table 1: Test-set classification accuracy (percent) on CIFAR-10 dataset using the SmallMLP network trained by various methods. Full experiment details are in Supplementary Material id1. Values are the mean classification accuracy over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened accuracies are those for which the maximum observed mean accuracy in the column was within one standard deviation. WD is weight decay; D is dropout.

Method Training Set Size
2500 10,000 40,000

SoftmaxCE
VIB, =
VIB, =
VIB, =
VIB, =
MASS, =
MASS, =
MASS, =
MASS, =

Table 2: Test-set classification accuracy (percent) on CIFAR-10 dataset using the ResNet20 network trained by various methods. No data augmentation or learning rate scheduling was used — full details in Supplementary Material id1. Values are the mean classification accuracy over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened accuracies are those for which the maximum observed mean accuracy in the column was within one standard deviation.
5.2. Uncertainty Quantification

We also evaluate the ability of networks trained by MASS Learning to properly quantify their uncertainty about their predictions. We assess uncertainty quantification in two ways: using proper scoring rules (Lakshminarayanan et al., 2017), which are scalar measures of how well a network’s predictive distribution is calibrated, and by assessing performance on an out-of-distribution (OOD) detection task.

Tables 3 through 8 show the uncertainty quantification performance of networks according to two proper scoring rules: the Negative Log Likelihood (NLL) and the Brier Score. The entropy and test accuracy of the predictive distributions are also given, for reference.

For the SmallMLP network in tables 3, 4, and 5, VIB provides the best combination of high accuracy and low NLL and Brier score across all sizes of training set. For the ResNet20 network in tables 6, 7, and 8, MASS Learning provides the best combination of accuracy and proper scoring rule performance. The NLL values for ResNet20 strongly indicate, as our theory suggests, that MASS Learning with larger $\beta$ leads to better-calibrated network predictions. The Brier score values show the same trend, though not decisively. Thus, as measured by proper scoring rules, MASS Learning can significantly improve the calibration of a network's predictions while simultaneously improving its accuracy.

Tables 9 through 14 show metrics for performance on an OOD detection task where the network predicts not just the class of the input image, but whether the image is from its training distribution (CIFAR-10 images) or from another distribution (SVHN images (Netzer et al., 2011)). Following Hendrycks & Gimpel (2017) and Alemi et al. (2018), the metrics we report for this task are the Area under the ROC curve (AUROC) and Average Precision score (APR). APR depends on whether the network is tasked with identifying in-distribution or out-of-distribution images; we report values for both cases as APR In and APR Out, respectively.

There are different detection methods that networks can use to identify OOD inputs. One, applicable to all training methods, is to use the entropy of the predictive distribution $p(y \mid x)$: the larger the entropy, the more likely the input is OOD. For networks trained by MASS Learning, the variational distribution $q_{\phi}(f_{\theta}(x))$ is a natural OOD detector: a small value of $q_{\phi}(f_{\theta}(x))$ suggests the input is OOD. For networks trained by SoftmaxCE, a distribution $q(f_{\theta}(x))$ over network outputs can be learned by MLE on the training set and used to detect OOD inputs in the same way.
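As an illustration of how these two detectors can be scored (a sketch of ours; `model`, `mass_loss`, `cifar_batch`, and `svhn_batch` are assumed to be a trained network, the fitted variational distribution from the earlier sketch, and preloaded in- and out-of-distribution batches):

```python
import torch
from sklearn.metrics import roc_auc_score, average_precision_score

@torch.no_grad()
def ood_scores(model, mass_loss, x):
    """Return per-example (predictive entropy, log q(f(x))) for a batch of inputs x."""
    z = model(x)
    log_joint = mass_loss.log_q_z_given_y(z) + mass_loss.log_prior   # (batch, classes)
    log_qz = torch.logsumexp(log_joint, dim=1)                       # log q(f(x))
    log_qy = log_joint - log_qz.unsqueeze(1)                         # predictive distribution q(y | f(x))
    entropy = -(log_qy.exp() * log_qy).sum(dim=1)
    return entropy, log_qz

# Score in-distribution (CIFAR-10) and out-of-distribution (SVHN) batches, then compute
# AUROC / APR with OOD as the positive class; negate log q(f(x)) so larger means "more OOD".
_, density_in = ood_scores(model, mass_loss, cifar_batch)
_, density_out = ood_scores(model, mass_loss, svhn_batch)
scores = torch.cat([-density_in, -density_out]).cpu().numpy()
labels = [0] * len(density_in) + [1] * len(density_out)             # 1 = OOD
auroc = roc_auc_score(labels, scores)
apr_out = average_precision_score(labels, scores)
apr_in = average_precision_score([1 - l for l in labels], -scores)  # in-distribution as positive class
```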

For the SmallMLP network in tables 9, 10, and 11 and the ResNet20 network in tables 12, 13, and 14, MASS Learning again performs comparably to or better than SoftmaxCE and VIB in all cases, but with $\beta = 0$ giving performance comparable to MASS Learning with $\beta > 0$.


Method Test Accuracy NLL Brier Score Entropy

SoftmaxCE
SoftmaxCE, WD
SoftmaxCE, D
VIB, =
VIB, =
VIB, =
VIB, =, D
VIB, =, D
VIB, =, D
MASS, =
MASS, =
MASS, =
MASS, =
MASS, =, D
MASS, =, D
MASS, =, D
MASS, =, D

Table 3: Uncertainty quantification metrics (proper scoring rules) on CIFAR-10 using the SmallMLP network trained on 40,000 datapoints. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the minimum observed mean value in the column was within one standard deviation. WD is weight decay; D is dropout. Lower values are better.

Method Test Accuracy NLL Brier Score Entropy

SoftmaxCE
SoftmaxCE, WD
SoftmaxCE, D
VIB, =
VIB, =
VIB, =
VIB, =, D
VIB, =, D
VIB, =, D
MASS, =
MASS, =
MASS, =
MASS, =
MASS, =, D
MASS, =, D
MASS, =, D
MASS, =, D

Table 4: Uncertainty quantification metrics (proper scoring rules) on CIFAR-10 using the SmallMLP network trained on 10,000 datapoints. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the minimum observed mean value in the column was within one standard deviation. WD is weight decay; D is dropout. Lower values are better.

Method Test Accuracy NLL Brier Score Entropy

SoftmaxCE
SoftmaxCE, WD
SoftmaxCE, D
VIB, =
VIB, =
VIB, =
VIB, =, D
VIB, =, D
VIB, =, D
MASS, =
MASS, =
MASS, =
MASS, =
MASS, =, D
MASS, =, D
MASS, =, D
MASS, =, D

Table 5: Uncertainty quantification metrics (proper scoring rules) on CIFAR-10 using the SmallMLP network trained on 2,500 datapoints. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the minimum observed mean value in the column was within one standard deviation. WD is weight decay; D is dropout. Lower values are better.

Method Test Accuracy NLL Brier Score Entropy

SoftmaxCE
VIB, =
VIB, =
VIB, =
VIB, =
MASS, =
MASS, =
MASS, =
MASS, =

Table 6: Uncertainty quantification metrics (proper scoring rules) on CIFAR-10 using the ResNet20 network trained on 40,000 datapoints. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the minimum observed mean value in the column was within one standard deviation. Lower values are better.

Method Test Accuracy NLL Brier Score Entropy

SoftmaxCE
VIB, =
VIB, =
VIB, =
VIB, =
MASS, =
MASS, =
MASS, =
MASS, =

Table 7: Uncertainty quantification metrics (proper scoring rules) on CIFAR-10 using the ResNet20 network trained on 10,000 datapoints. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the minimum observed mean value in the column was within one standard deviation. Lower values are better.

Method Test Accuracy NLL Brier Score Entropy

SoftmaxCE
VIB, =
VIB, =
VIB, =
VIB, =
MASS, =
MASS, =
MASS, =
MASS, =

Table 8: Uncertainty quantification metrics (proper scoring rules) on CIFAR-10 using the ResNet20 network trained on 2,500 datapoints. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the minimum observed mean value in the column was within one standard deviation. Lower values are better.

Training Method Test Accuracy Detection Method AUROC APR In APR Out

SoftmaxCE
Entropy
SoftmaxCE, WD Entropy
SoftmaxCE, D Entropy
VIB, = Entropy
Rate
VIB, = Entropy
Rate
VIB, = Entropy
Rate
VIB, =, D Entropy
Rate
VIB, =, D Entropy
Rate
VIB, =, D Entropy
Rate
MASS, = Entropy
MASS, = Entropy
MASS, = Entropy
MASS, = Entropy
MASS, =, D Entropy
MASS, =, D Entropy
MASS, =, D Entropy
MASS, =, D Entropy

Table 9: Out-of-distribution detection metrics for SmallMLP network trained on 40,000 CIFAR-10 images, with SVHN as the out-of-distribution examples. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the maximum observed mean value in the column was within one standard deviation. WD is weight decay; D is dropout. Higher values are better.

Training Method Test Accuracy Detection Method AUROC APR In APR Out

SoftmaxCE
Entropy
SoftmaxCE, WD Entropy
SoftmaxCE, D Entropy
VIB, = Entropy
Rate
VIB, = Entropy
Rate
VIB, = Entropy
Rate
VIB, =, D Entropy
Rate
VIB, =, D Entropy
Rate
VIB, =, D Entropy
Rate
MASS, = Entropy
MASS, = Entropy
MASS, = Entropy
MASS, = Entropy
MASS, =, D Entropy
MASS, =, D Entropy
MASS, =, D Entropy
MASS, =, D Entropy

Table 10: Out-of-distribution detection metrics for SmallMLP network trained on 10,000 CIFAR-10 images, with SVHN as the out-of-distribution examples. Full experiment details are in Supplementary Material id1. Values are the mean over 4 training runs with different random seeds, plus or minus the standard deviation. Emboldened values are those for which the maximum observed mean value in the column was within one standard deviation. WD is weight decay; D is dropout. Higher values are better.

Training Method Test Accuracy Detection Method AUROC APR In APR Out

SoftmaxCE
Entropy