Deep Probabilistic Ensembles: Approximate Variational Inference through KL Regularization
In this paper, we introduce Deep Probabilistic Ensembles (DPEs), a scalable technique that uses a regularized ensemble to approximate a deep Bayesian Neural Network (BNN). We do so by incorporating a KL divergence penalty term into the training objective of an ensemble, derived from the evidence lower bound used in variational inference. We evaluate the uncertainty estimates obtained from our models for active learning on visual classification, consistently outperforming baselines and existing approaches.
1 Introduction
Modeling uncertainty for deep neural networks has a wide range of potential applications: it can provide important information about the reliability of predictions, or better strategies for labeling data in order to improve performance. Bayesian methods, which provide an approach to do this, have recently gained momentum, and are beginning to find more widespread use in practice graves2011practical (); blundell2015weight (); gal2015bayesian (); kendall2017uncertainties ().
The formulation of a Bayesian Neural Network (BNN) involves placing a prior distribution over all the parameters of a network, and obtaining the posterior given the observed data neal1995bayesian (). Existing BNN training approximations limit their applicability, since they do not specifically address the fact that deep BNNs on large datasets are far harder to optimize than deterministic networks, and require extensive parameter tuning to provide good performance and uncertainty estimates osband2016risk (). Furthermore, estimating uncertainty in BNNs requires drawing a large number of samples at test time (typically sequentially), which can be extremely computationally demanding.
In practice, a common approach to estimating uncertainty is based on ensembles lakshminarayanan2017simple (); beluch2018power (); geifman2018boosting (). In this paper, we use ensembles to approximate BNNs, by introducing a form of regularization to the ensemble training objective. The KL divergence penalty term we use is derived from variational inference, a popular technique used to train BNNs blei2016variational ().
2 Deep Probabilistic Ensembles
For inference in a BNN, we can consider the weights $w$ to be latent variables drawn from our prior distribution, $p(w)$. These weights relate to the observed dataset $x$ through the likelihood, $p(x \mid w)$. We aim to compute the posterior $p(w \mid x)$ that best explains the observed data. Variational inference involves restricting ourselves to a family of distributions $D$ over the latent variables, and optimizing for the member of this family $q^*(w)$ that is closest to the true posterior in terms of KL divergence,

$$q^*(w) = \arg\min_{q(w) \in D} \mathrm{KL}\big(q(w)\,\|\,p(w \mid x)\big). \tag{1}$$

Simplifying this leads to a minimization objective, which is the negative Evidence Lower Bound (ELBO) blei2016variational (),

$$-\mathrm{ELBO} = \mathrm{KL}\big(q(w)\,\|\,p(w)\big) - \mathbb{E}_{q(w)}\big[\log p(x \mid w)\big]. \tag{2}$$

The first term in this objective is the KL divergence between the distribution of weights in the approximated BNN $q(w)$ and the prior distribution $p(w)$. The second term is the expected negative log likelihood (NLL) of the data $x$ based on the current parameters $w$. Since the expectation is over all possible assignments of $w$, we seek to make any generic deterministic network sampled from the BNN as close a fit to the data as possible.
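The simplification from the KL divergence to the negative ELBO can be spelled out in a few lines. The notation below is an assumption chosen for illustration: $w$ for the weights, $x$ for the data, $q(w)$ for the variational distribution and $p(w)$ for the prior. The only ingredient is Bayes' rule, $\log p(w \mid x) = \log p(x \mid w) + \log p(w) - \log p(x)$.

```latex
\begin{align*}
\mathrm{KL}\big(q(w)\,\|\,p(w \mid x)\big)
  &= \mathbb{E}_{q(w)}\big[\log q(w) - \log p(w \mid x)\big] \\
  &= \mathbb{E}_{q(w)}\big[\log q(w) - \log p(w) - \log p(x \mid w)\big] + \log p(x) \\
  &= \mathrm{KL}\big(q(w)\,\|\,p(w)\big)
     - \mathbb{E}_{q(w)}\big[\log p(x \mid w)\big] + \log p(x)
\end{align*}
```

Since $\log p(x)$ does not depend on $q$, minimizing the left-hand side over $q$ is equivalent to minimizing the sum of the first two terms on the last line, which is the negative ELBO.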
The optimization difficulty in variational inference arises partly due to fragile co-adaptations that exist between different parameters in deterministic networks, which are crucial to their performance yosinski2014transferable (). Features typically interact with each other in a complex way, such that the optimal setting for certain parameters is highly dependent on specific configurations of the other parameters in the network. The variance in our weight distributions prevents the BNN from exploiting co-adaptations during optimization, making it harder to optimize a BNN to match the performance of a deterministic network with the same number of parameters through standard variational inference or certain approximations to it blundell2015weight ().
KL regularization. We observe that an ensemble of networks perfectly exploits the co-adaptations in parameters, as each member of the ensemble is optimized independently. We can rewrite the minimization objective from Eq. 2 for an ensemble as follows:

$$\Theta^* = \arg\min_{\Theta} \; \frac{1}{E} \sum_{i=1}^{N} \sum_{e=1}^{E} H\big(y^i, M_e(x^i, \Theta_e)\big) + \beta\,\Omega(\Theta), \tag{3}$$

where $\{(x^i, y^i)\}_{i=1}^{N}$ is the training data, $E$ is the number of models in the ensemble, $H$ is the cross-entropy loss for classification, $\Theta_e$ refers to the parameters of the model $M_e$, and $\Omega$ is our KL regularization penalty over the joint set of all parameters $\Theta$, weighted by a scaling term $\beta$. By averaging the loss over each independent model, we are calculating the expectation of the ELBO’s NLL term over only the current ensemble configurations, a subset of all possible assignments of $\Theta$. This is the main distinction between our approach and traditional variational inference.
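The structure of this objective can be illustrated with a minimal sketch: the cross-entropy is averaged over ensemble members and a scaled joint penalty is added. The function names, the toy probabilities, and the values passed for `kl_penalty` and `beta` are all assumptions made for illustration, not values from the paper.

```python
import math


def cross_entropy(probs, label):
    """Negative log-probability assigned to the true class."""
    return -math.log(probs[label])


def ensemble_objective(member_probs, labels, kl_penalty, beta):
    """Member-averaged cross-entropy plus a scaled joint penalty.

    member_probs[e][i] is the list of class probabilities that
    (hypothetical) ensemble member e predicts for training example i.
    kl_penalty stands in for the regularizer over all ensemble parameters.
    """
    num_members = len(member_probs)
    data_loss = 0.0
    for probs_e in member_probs:
        data_loss += sum(cross_entropy(p, y) for p, y in zip(probs_e, labels))
    return data_loss / num_members + beta * kl_penalty


# Toy example: 2 members, 2 training examples, 2 classes.
member_probs = [
    [[0.9, 0.1], [0.2, 0.8]],
    [[0.6, 0.4], [0.3, 0.7]],
]
labels = [0, 1]
loss = ensemble_objective(member_probs, labels, kl_penalty=2.0, beta=0.1)
print(loss)
```

Each member contributes its own data term, so members are free to fit the data independently; only the penalty couples them.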
The standard approach to training neural networks involves independently regularizing the parameters of each ensemble member with $L_1$ or $L_2$ regularization terms. We instead apply the KL divergence term in Eq. 2 to the objective as a regularization penalty $\Omega$, computed on the distribution of values of a given parameter over all members in the ensemble. If we choose the family of Gaussian functions for $p$ and $q$, this term can be analytically computed by assuming mutual independence between the network parameters and factoring the term into individual Gaussians. The KL divergence between two Gaussians with means $\mu_q$ and $\mu_p$, standard deviations $\sigma_q$ and $\sigma_p$ is given by

$$\mathrm{KL}(q\,\|\,p) = \log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2} - \frac{1}{2}. \tag{4}$$

We sum up this penalty over all the parameters of the ensemble, using a scaling term $\beta$ to balance the regularizer with the likelihood loss. $\Omega$ can be obtained by removing the terms independent of $q$ and substituting for the chosen values of $\mu_p$ and $\sigma_p$ from our prior into Eq. 4.
As a prior, we use the network initialization technique proposed in he2015delving (). More specifically, for convolutional layers with the ReLU activation, we use zero-centered Gaussian distributions, with variances inversely proportional to the number of kernel parameters. We use fixed variance Gaussians for the batch normalization parameters, with mean 1 for the weights and 0 for the biases. For each parameter in a convolutional layer of dimensions ,
The first term prevents extremely large variances compared to the prior, so the ensemble members do not diverge completely from each other. The second term heavily penalizes variances less than the prior, promoting diversity between members. The third term closely resembles weight decay, keeping the mean of the weights close to that of the prior, especially when their variance is also low.
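A per-parameter Gaussian penalty of this kind can be sketched as follows. This is an illustration under stated assumptions, not the paper's exact regularizer: it uses the standard closed form for the KL divergence from the ensemble distribution to the prior, keeps the constant terms, and takes a zero-mean prior with variance 2 / (number of kernel parameters), where the 64 x 3 x 3 kernel size is hypothetical. The direction of the divergence and the dropped constants in the original formulation may differ.

```python
import math
from statistics import fmean, pvariance


def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for univariate Gaussians."""
    return 0.5 * (
        math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    )


def ensemble_kl_penalty(member_values, prior_mean, prior_var, eps=1e-8):
    """Sum the Gaussian KL over all parameters.

    member_values[e][j] is the value of parameter j in ensemble member e.
    The mean and variance of each parameter across members define its
    Gaussian; eps keeps the variance strictly positive.
    """
    total = 0.0
    for j in range(len(member_values[0])):
        values = [member[j] for member in member_values]
        mu_q = fmean(values)
        var_q = pvariance(values, mu_q) + eps
        total += gaussian_kl(mu_q, var_q, prior_mean, prior_var)
    return total


# Assumed prior for illustration: zero mean, variance 2 / (64 * 3 * 3).
prior_var = 2.0 / (64 * 3 * 3)

spread = [[0.05, -0.02], [-0.06, 0.07], [0.02, -0.04]]   # diverse members
collapsed = [[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]]          # identical members

print(ensemble_kl_penalty(spread, 0.0, prior_var))
print(ensemble_kl_penalty(collapsed, 0.0, prior_var))
```

In this sketch the collapsed ensemble receives a much larger penalty than the diverse one, because its across-member variance is far below the prior variance.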
3 Experiments
Active learning allows a model to choose the data from which it learns, by iteratively using the partially trained model to decide which examples to annotate from a large unlabeled dataset cohn1994active (). In our approach, the data with the highest entropy for the average prediction across the ensemble members is added to the training set from the unlabeled pool. We experiment with active learning on the CIFAR dataset, which has two object classification tasks over natural images: one coarse-grained over 10 classes and one fine-grained over 100 classes krizhevsky2009learning ().
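The acquisition rule described above (entropy of the ensemble's average prediction) can be sketched in a few lines. The function names and the toy predictions are assumptions for illustration; in practice the probabilities would come from the softmax outputs of each ensemble member.

```python
import math


def mean_prediction(member_probs):
    """Average the class probabilities predicted by each ensemble member."""
    num_members = len(member_probs)
    num_classes = len(member_probs[0])
    return [
        sum(probs[c] for probs in member_probs) / num_members
        for c in range(num_classes)
    ]


def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def select_most_uncertain(pool_predictions, budget):
    """Return indices of the `budget` pool items with the highest entropy.

    pool_predictions[i] holds one probability list per ensemble member
    for unlabeled item i.
    """
    scores = [entropy(mean_prediction(m)) for m in pool_predictions]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:budget]


# Toy pool: 3 unlabeled items, 2 ensemble members, 2 classes.
pool = [
    [[0.9, 0.1], [0.8, 0.2]],      # members agree, fairly confident
    [[0.9, 0.1], [0.1, 0.9]],      # members disagree
    [[0.99, 0.01], [0.98, 0.02]],  # members agree, very confident
]
chosen = select_most_uncertain(pool, budget=2)
print(chosen)
```

Averaging before taking the entropy means that disagreement between confident members also produces a high score, as with the second toy item.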
We use eight models in our ensembles, each a pre-activation ResNet-18 he2016deep (). Our first experiment involves initializing models by training with a random subset of 4% of the dataset, and then re-training the model at three additional intervals after adding more data (at 8%, 16% and 32% of the data). We evaluate three approaches for adding data to the training set: (1) Random: pick the required percentage of data through random sampling; (2) Ensemble: pick the samples with highest average entropy across our eight models with standard regularization for each model; and (3) Deep Probabilistic Ensemble (DPE): pick samples with the highest average entropy across our eight models jointly regularized with our KL regularization. Our results are shown in Table 1. Our figures correspond to the mean accuracy of 3 experimental trials. Ensemble-based active learning significantly outperforms random baselines across all experiments, and DPEs further improve on these results.
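The labeling schedule of this experiment (a random 4% start, then growth to 8%, 16% and 32%) can be sketched as a simple loop. Everything here is an illustrative assumption: `run_rounds` and `toy_score` are hypothetical names, the pool size is arbitrary, model training is omitted, and the scoring function is a deterministic stand-in for an uncertainty score.

```python
import random


def run_rounds(pool_size, fractions, score_fn, seed=0):
    """Grow a labeled set to each target fraction of the pool in turn."""
    rng = random.Random(seed)
    all_indices = list(range(pool_size))
    # Round 0: a random subset forms the initial labeled set.
    labeled = set(rng.sample(all_indices, round(fractions[0] * pool_size)))
    sizes = [len(labeled)]
    for fraction in fractions[1:]:
        # A real pipeline would re-train the models on `labeled` here.
        unlabeled = [i for i in all_indices if i not in labeled]
        scores = score_fn(unlabeled)
        # Highest-scoring unlabeled items are labeled first.
        ranked = sorted(zip(scores, unlabeled), reverse=True)
        needed = round(fraction * pool_size) - len(labeled)
        labeled.update(index for _, index in ranked[:needed])
        sizes.append(len(labeled))
    return labeled, sizes


def toy_score(indices):
    """Stand-in for an uncertainty score; deterministic for the example."""
    return [(i * 37) % 101 for i in indices]


labeled, sizes = run_rounds(1000, [0.04, 0.08, 0.16, 0.32], toy_score)
print(sizes)  # [40, 80, 160, 320]
```

Replacing `toy_score` with random scores corresponds to the Random baseline, while an entropy-based score corresponds to the ensemble-based strategies.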
In our second experiment, we compare approaches that use different base models, but the same amount of labeled data (20% of the dataset), focusing on CIFAR-10 to compare with existing results. For our models, we start at 4% of the data and then re-train the model at four additional intervals (at 8%, 12%, 16% and 20%). Here, we include the results for (1) the DPE, (2) a deterministic network (with the same architecture, and regularization), (3) results from sener2018active (), which uses a geometry-based method called core-set selection with a single VGG-16, and (4) the current active learning state-of-the-art beluch2018power (), which uses an ensemble of DenseNet-121 models. Our approach gets far higher relative performance compared to its upper bound, as shown in Table 2.
Table 1: Active learning accuracy on CIFAR-10 and CIFAR-100 for each data sampling approach.

| Task | Data Sampling | Accuracy @8% | Accuracy @16% | Accuracy @32% |
| --- | --- | --- | --- | --- |
| CIFAR-10 | Random | 80.60 (84.66) | 86.80 (91.18) | 91.08 (95.67) |
| CIFAR-10 | Ensemble | 82.41 (86.56) | 90.05 (94.59) | 94.13 (98.87) |
| CIFAR-10 | DPE (Ours) | 82.88 (87.06) | 90.15 (94.70) | 94.33 (99.09) |
| CIFAR-100 | Random | 39.57 (50.18) | 54.92 (69.64) | 66.65 (84.51) |
| CIFAR-100 | Ensemble | 40.49 (51.34) | 56.89 (72.14) | 69.68 (88.36) |
| CIFAR-100 | DPE (Ours) | 40.87 (51.83) | 56.94 (72.20) | 70.12 (88.92) |
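Table 2: Comparison with existing approaches on CIFAR-10 using 20% of the labeled data.

| Method | Accuracy @20% | Accuracy @100% | Relative Performance |
| --- | --- | --- | --- |
| Core-set sener2018active () | 74% | 90% | 82.2% |
| Ensemble beluch2018power () | 85% | 95.5% | 89% |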
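The Relative Performance column is consistent with the ratio of the accuracy at 20% of the data to the accuracy at 100% of the data. A short check on the two rows listed in Table 2, assuming that definition (the function name is hypothetical):

```python
def relative_performance(acc_subset, acc_full):
    """Accuracy on a labeled subset as a percentage of full-data accuracy."""
    return 100.0 * acc_subset / acc_full


# (accuracy @20%, accuracy @100%) as listed in Table 2.
rows = {
    "Core-set": (74.0, 90.0),
    "Ensemble": (85.0, 95.5),
}
for method, (acc_20, acc_100) in rows.items():
    print(f"{method}: {relative_performance(acc_20, acc_100):.1f}%")
# Core-set: 82.2%
# Ensemble: 89.0%
```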
4 Conclusion
In this paper, we introduced DPEs, which perform approximate variational inference for BNNs by training ensembles with our novel KL regularization scheme. Our results demonstrate that DPEs improve performance on active learning tasks over strong baselines like vanilla ensembles and state-of-the-art active learning techniques on image data. We argue that they do so by providing better uncertainty information on unlabeled data, and look forward to future work with these models for any downstream tasks requiring precise uncertainty information.
References
-  William H. Beluch, Tim Genewein, Andreas Nürnberger, and Jan M. Köhler. The power of ensembles for active learning in image classification. In CVPR, 2018.
-  D. M. Blei, A. Kucukelbir, and J. D. McAuliffe. Variational Inference: A Review for Statisticians. ArXiv e-prints, 2016.
-  Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural networks. In ICML, 2015.
-  David Cohn, Les Atlas, and Richard Ladner. Improving generalization with active learning. Machine Learning, 1994.
-  Y. Gal and Z. Ghahramani. Bayesian Convolutional Neural Networks with Bernoulli Approximate Variational Inference. ArXiv e-prints, 2015.
-  Y. Geifman, G. Uziel, and R. El-Yaniv. Boosting Uncertainty Estimation for Deep Neural Classifiers. ArXiv e-prints, 2018.
-  Alex Graves. Practical variational inference for neural networks. In NIPS, 2011.
-  Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In ICCV, 2015.
-  Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, 2016.
-  Alex Kendall and Yarin Gal. What uncertainties do we need in bayesian deep learning for computer vision? In NIPS, 2017.
-  Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009.
-  Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In NIPS, 2017.
-  Radford M. Neal. Bayesian learning for neural networks. PhD thesis, University of Toronto, 1995.
-  Ian Osband. Risk versus uncertainty in deep learning: Bayes, bootstrap and the dangers of dropout. In NIPS Workshops, 2016.
-  Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A core-set approach. In ICLR, 2018.
-  Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. How transferable are features in deep neural networks? In NIPS, 2014.