Uncertainty Estimation via Stochastic Batch Normalization

Andrei Atanov, Arsenii Ashukha, Dmitry Molchanov, Kirill Neklyudov, Dmitry Vetrov
National Research University Higher School of Economics, University of Amsterdam, Yandex
andrewatanov@yandex.ru, ars.ashuh@gmail.com
dmolch111@gmail.com, k.necludov@gmail.com, vetrodim@gmail.com

In this work, we investigate the Batch Normalization technique and propose its probabilistic interpretation. We propose a probabilistic model and show that Batch Normalization maximizes the lower bound of its marginalized log-likelihood. Then, according to the new probabilistic model, we design an algorithm which acts consistently during training and testing. However, inference becomes computationally inefficient. To reduce memory and computational cost, we propose Stochastic Batch Normalization, an efficient approximation of the proper inference procedure. This method provides us with a scalable uncertainty estimation technique. We demonstrate the performance of Stochastic Batch Normalization on popular architectures (including deep convolutional architectures: VGG-like and ResNets) for the MNIST and CIFAR-10 datasets.



1 Introduction

Deep Neural Networks (DNNs) have achieved state-of-the-art quality on many problems and are successfully integrated in real-life scenarios: semantic segmentation, object detection and scene recognition, to name but a few. The quality of a model is usually measured in terms of accuracy; however, accurate uncertainty estimation is also crucial for real-life decision-making applications, such as self-driving systems and medical diagnostics. Despite their high accuracy, DNNs are prone to overconfidence even on out-of-domain data.

The Bayesian framework lends itself well to uncertainty estimation (MacKay, 1992), but exact Bayesian inference is intractable for large models such as DNNs. To address this issue, a number of approximate inference techniques have been proposed recently (Welling & Teh, 2011; Hoffman et al., 2013). It has been shown that Dropout, a well-known regularization technique (Srivastava et al., 2014), can be treated as a special case of stochastic variational inference (Kingma et al., 2015; Molchanov et al., 2017). Gal & Ghahramani (2015) also showed that the stochasticity induced by Dropout can provide well-calibrated uncertainty estimation for DNNs. Multiplicative Normalizing Flows (Louizos & Welling, 2017) is another approximation technique that produces high-quality uncertainty estimates. However, such a complex method is hard to scale to very deep convolutional architectures. Moreover, the recently proposed Residual Network (He et al., 2015) with more than a hundred layers does not have any noise-inducing layers such as Dropout, since this type of layer leads to a significant accuracy degradation (He et al., 2015). This problem can be addressed by the non-Bayesian Deep Ensembles method (Lakshminarayanan et al., 2017), which provides competitive uncertainty estimation, but requires storing several separate models and performing forward passes through all of them to make a prediction.

Batch Normalization (Ioffe & Szegedy, 2015) is an essential part of very deep convolutional architectures. In our work, we treat Batch Normalization as a stochastic layer and propose a way to ensemble batch-normalized networks. The straightforward technique, however, ends up with high memory and computational cost. We therefore propose Stochastic Batch Normalization (SBN), an efficient and scalable approximation technique. We show the performance of our method on the out-of-domain uncertainty estimation problem for deep convolutional architectures including VGG-like, ResNet and LeNet-5 on the MNIST and CIFAR10 datasets. We also demonstrate that SBN successfully extends the Dropout and Deep Ensembles methods.

2 Method

We consider a supervised learning problem with a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$. The goal is to train the parameters $\theta$ of the predictive likelihood $p_\theta(y \mid x)$, modelled by a neural network. To solve this problem, stochastic optimization methods with a mini-batch gradient estimator are usually used.

Batch Normalization. Batch Normalization attempts to keep the activations of all layers at zero mean and unit variance. To do so, it uses the mean and variance over the mini-batch during training and accumulated statistics during inference:

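$$\text{Training: } \hat{x} = \frac{x - \mu(B)}{\sqrt{\sigma^2(B) + \epsilon}}\,\gamma + \beta, \qquad \text{Testing: } \hat{x} = \frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}}\,\gamma + \beta \tag{1}$$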
where $\gamma, \beta$ are the trainable Batch Normalization parameters (scale and shift) and $\epsilon$ is a small constant needed for numerical stability. Note that during training the mean and variance are computed over a randomly picked batch $B$ ($\mu(B)$, $\sigma^2(B)$), while during testing the exponentially smoothed statistics ($\hat{\mu}$, $\hat{\sigma}^2$) are used. We address this inconsistency with the proposed probabilistic model.
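The train/test inconsistency described above can be made concrete with a minimal single-feature sketch. It is an illustration under stated assumptions rather than the authors' implementation: the class name `BatchNorm1`, the `momentum` value and the use of the biased variance are all hypothetical choices.

```python
import math


def batch_stats(batch):
    """Mean and (biased) variance of a list of scalar activations."""
    mean = sum(batch) / len(batch)
    var = sum((v - mean) ** 2 for v in batch) / len(batch)
    return mean, var


def normalize(x, mean, var, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one activation, then apply scale (gamma) and shift (beta)."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


class BatchNorm1:
    """Single-feature Batch Normalization (hypothetical helper class)."""

    def __init__(self, momentum=0.1):
        self.momentum = momentum      # assumed smoothing rate
        self.running_mean = 0.0       # exponentially smoothed mean
        self.running_var = 1.0        # exponentially smoothed variance

    def forward_train(self, batch):
        # Training: statistics come from the current (random) mini-batch.
        mean, var = batch_stats(batch)
        a = self.momentum
        self.running_mean = (1 - a) * self.running_mean + a * mean
        self.running_var = (1 - a) * self.running_var + a * var
        return [normalize(v, mean, var) for v in batch]

    def forward_test(self, x):
        # Testing: the smoothed statistics are used, so the output for x
        # no longer depends on which other objects share its batch.
        return normalize(x, self.running_mean, self.running_var)
```

During training the output for a given activation changes with the rest of the mini-batch, whereas `forward_test` is deterministic; this is the gap the probabilistic model is meant to close.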

Batch Normalization: Probabilistic View. Note from (1) that a forward pass through the batch-normalized network depends not only on $x$ but on the entire batch $B$ as well. This dependency can be reinterpreted in terms of the mini-batch statistics $\mu(B), \sigma^2(B)$:

$$p_\theta(y \mid x, B_{\setminus x}) = p_\theta\big(y \mid x, \mu(B), \sigma^2(B)\big) \tag{2}$$
where $B_{\setminus x}$ is a batch without $x$. Due to the stochastic choice of mini-batches during training, for a fixed $x$ the rest of the batch $B_{\setminus x}$ is a random variable, so the mini-batch statistics can be treated as random variables. The conditional distribution $r(\mu, \sigma^2 \mid x, B)$ is the product of two Dirac delta functions, centered at $\mu(B)$ and $\sigma^2(B)$, since the statistics are deterministic functions of the mini-batch, and the distribution of the mean and variance given $x$, $r(\mu, \sigma^2 \mid x)$, is an expectation over the mini-batch distribution. During inference we average the distribution $p_\theta(y \mid x, \mu, \sigma^2)$ over the normalization statistics:
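$$p_\theta(y \mid x) = \mathbb{E}_{r(\mu, \sigma^2 \mid x)}\, p_\theta(y \mid x, \mu, \sigma^2), \qquad r(\mu, \sigma^2 \mid x) = \mathbb{E}_{B}\, r(\mu, \sigma^2 \mid x, B) \tag{3}$$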

Connection to Batch Normalization. In Appendix A we show that during training Batch Normalization (1) performs an unbiased one-sample MC estimation of the gradient of a lower bound on the marginalized log-likelihood (3). Thus, this probabilistic model corresponds to Batch Normalization during training. However, at test time Batch Normalization uses the exponentially smoothed statistics $\hat{\mu}, \hat{\sigma}^2$, which can be seen as a biased approximation of (3):
$$\mathbb{E}_{r(\mu, \sigma^2 \mid x)}\, p_\theta(y \mid x, \mu, \sigma^2) \approx p_\theta(y \mid x, \hat{\mu}, \hat{\sigma}^2)$$

Straightforward MC averaging gives a better, unbiased estimate of (3); however, it is inefficient in practice. Indeed, to draw one sample from the distribution over statistics in (3) we need to pass an entire mini-batch through the network. So, to perform MC averaging for a single test object, we need several forward passes with different mini-batches sampled from the training data. To address this drawback we propose Stochastic Batch Normalization.
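To make the cost of straightforward MC averaging concrete, here is a minimal sketch for a toy classifier with one batch-normalized feature. The model, its weights `w` and `b`, and the function names are assumptions made for illustration only; in a real network every sample would require a forward pass of a whole mini-batch through all layers.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict_given_stats(x, mean, var, w=2.0, b=0.0, eps=1e-5):
    """p(y=1 | x, mu, sigma^2) for a toy one-feature normalized classifier."""
    x_hat = (x - mean) / math.sqrt(var + eps)
    return sigmoid(w * x_hat + b)


def mc_predict(x, train_data, batch_size=8, n_samples=50, rng=None):
    """Average predictions over the statistics of random mini-batches.

    Each sample places the test object x in a freshly drawn mini-batch of
    training objects and recomputes the batch mean and variance, so the
    cost grows with batch_size * n_samples.
    """
    rng = rng or random.Random(0)
    total = 0.0
    for _ in range(n_samples):
        batch = [x] + rng.sample(train_data, batch_size - 1)
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
        total += predict_given_stats(x, mean, var)
    return total / n_samples
```

Note that the training data must be kept around at test time, which is the memory cost the text refers to.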

Stochastic Batch Normalization. To reduce the memory and computational cost of straightforward MC estimation, we propose to approximate the distribution of Batch Normalization statistics with a fully-factorized parametric approximation $r(\mu, \sigma^2 \mid x) \approx r(\mu)\, r(\sigma^2)$. We parameterize $r(\mu)$ and $r(\sigma^2)$ in the following way:
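$$r(\mu) = \mathcal{N}(\mu \mid m_\mu, s_\mu^2), \qquad r(\sigma^2) = \mathrm{Log}\mathcal{N}(\sigma^2 \mid m_\sigma, s_\sigma^2) \tag{4}$$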

This approximation works well in practice: in Appendix B we show that it accurately fits the real marginals. Since the approximation no longer depends on the training data, samples for each layer can be computed without passing an entire batch through the network, which makes prediction efficient.
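A minimal sketch of test-time sampling under this approximation, assuming a Gaussian over the mean and a log-normal over the variance as described in Appendix B. The parameter tuple `(m_mu, s_mu, m_sigma, s_sigma)`, the toy classifier weights and the function names are hypothetical; `m_sigma` and `s_sigma` are taken to be the mean and standard deviation of the log-variance.

```python
import math
import random


def sample_stats(m_mu, s_mu, m_sigma, s_sigma, rng):
    """Draw (mean, variance) from the factorized approximation."""
    mean = rng.gauss(m_mu, s_mu)                # Gaussian over the mean
    var = rng.lognormvariate(m_sigma, s_sigma)  # log-normal over the variance
    return mean, var


def sbn_forward(x, params, rng, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize x with sampled statistics; no mini-batch is needed."""
    mean, var = sample_stats(*params, rng)
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


def sbn_predict(x, params, n_samples=20, rng=None, w=2.0, b=0.0):
    """Average a toy classifier's predictions over sampled statistics."""
    rng = rng or random.Random(0)
    probs = []
    for _ in range(n_samples):
        z = w * sbn_forward(x, params, rng) + b
        probs.append(1.0 / (1.0 + math.exp(-z)))
    return sum(probs) / len(probs)
```

Each sample costs one forward pass for the single test object, in contrast to one pass per mini-batch for the straightforward MC estimate.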

To fit the parameters of the approximation, we minimize the KL-divergence between the distribution induced by Batch Normalization (3) and our approximation for each object:
$$\mathrm{KL}\big(r(\mu, \sigma^2 \mid x)\,\|\, r(\mu)\, r(\sigma^2)\big) \to \min_{m_\mu,\, s_\mu,\, m_\sigma,\, s_\sigma}$$

Since the approximation belongs to the exponential family, this minimization problem is equivalent to moment matching and does not require gradient computation. In our implementation we simply use exponential smoothing to approximate the sufficient statistics of the mean and variance distributions. This can be done for any pre-trained batch-normalized network.
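The moment-matching step can be sketched as follows, assuming the Gaussian and log-normal families from Appendix B. The sufficient statistics are then the first two moments of the batch mean and of the log of the batch variance. The class name `StatsApproximation` and the `decay` parameter are hypothetical, and this is one plausible reading of "exponential smoothing", not the authors' code.

```python
import math


class StatsApproximation:
    """Tracks smoothed sufficient statistics of mini-batch mean and variance."""

    def __init__(self, decay=0.99):
        self.decay = decay
        # Smoothed [E mu, E mu^2, E log var, E (log var)^2]; set on first update.
        self.moments = None

    def update(self, batch_mean, batch_var):
        """Fold the statistics of one training mini-batch into the estimates."""
        log_var = math.log(batch_var)
        new = [batch_mean, batch_mean ** 2, log_var, log_var ** 2]
        if self.moments is None:
            self.moments = new
        else:
            d = self.decay
            self.moments = [d * old + (1 - d) * cur
                            for old, cur in zip(self.moments, new)]

    def params(self):
        """Return (m_mu, s_mu, m_sigma, s_sigma) by moment matching."""
        e_mu, e_mu2, e_lv, e_lv2 = self.moments
        s_mu = math.sqrt(max(e_mu2 - e_mu ** 2, 0.0))
        s_sigma = math.sqrt(max(e_lv2 - e_lv ** 2, 0.0))
        return e_mu, s_mu, e_lv, s_sigma
```

Because `update` only reads per-batch statistics, it can be run over a few passes of training mini-batches through an already trained network without touching its weights.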

3 Experiments

Network Method Error% NLL
LeNet-5 MNIST SBN 0.53 0.05
Deep Ensembles 0.43 0.00
Dropout 0.49 0.00 0.015 0.000
Deep Ensembles 5.18 0.00 5.23 0.00 0.154 0.002
Dropout 5.32 0.00 0.149 0.001
ResNet-18 CIFAR5 SBN
Deep Ensembles 3.37 0.00 0.110 0.004
Table 1: Test errors (%) and NLL scores for known classes: MNIST for LeNet-5 and CIFAR5 for VGG-11 and ResNet-18. The SBN columns correspond to methods with all Batch Normalization layers replaced by our SBN.

We evaluate uncertainties on the MNIST and CIFAR10 datasets using convolutional architectures. To apply Stochastic Batch Normalization to existing architectures we only need to update the parameters of our approximation (4), which does not affect the training process at all. We show that SBN improves both Dropout and Deep Ensembles in terms of out-of-domain uncertainty and test Negative Log-Likelihood (NLL), while maintaining the same level of accuracy.

Experimental Setup. We compare our method with Dropout and Deep Ensembles. Since He et al. (2015) showed that ResNet does not perform well with any Dropout layer and suffers from instability, we did not consider Dropout for the ResNet architecture. For Deep Ensembles we trained 6 models for all architectures and did not use the adversarial training suggested by Lakshminarayanan et al. (2017), since this technique results in lower accuracy.

Uncertainty estimation on notMNIST. For this experiment we trained a LeNet-5 model on MNIST and evaluated the entropy of the predictive distribution on notMNIST, which is out-of-domain data for MNIST; the empirical CDF is plotted in Fig. 1(a). We also report the test error and NLL scores in Tab. 1.
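The evaluation metric used here can be sketched in a few lines: compute the entropy of each predictive distribution on out-of-domain inputs, then read off the empirical CDF at a chosen threshold. The function names are illustrative, and the predictive distributions are assumed to be given as lists of class probabilities.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of a categorical predictive distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def empirical_cdf(values, threshold):
    """Fraction of values that are less than or equal to the threshold."""
    return sum(v <= threshold for v in values) / len(values)


def entropy_cdf(predictions, thresholds):
    """Empirical CDF of predictive entropy, evaluated at each threshold."""
    entropies = [entropy(p) for p in predictions]
    return [empirical_cdf(entropies, t) for t in thresholds]
```

On out-of-domain data a well-behaved model should assign high entropy to most inputs, so its CDF curve stays low for small thresholds and rises late.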

Uncertainty estimation on CIFAR10. To show that our method scales well to deep convolutional architectures, we perform experiments on VGG-like and ResNet architectures. We split the CIFAR10 dataset into two datasets of five classes each (CIFAR5) and perform an experiment similar to the one for MNIST and notMNIST. We trained the networks on 5 randomly chosen classes and evaluated predictive uncertainty on the remaining ones.

We observed that Stochastic Batch Normalization improves both Dropout and Deep Ensembles in terms of out-of-domain uncertainty and NLL score on test data (from the same domain) at the same level of accuracy. However, SBN on its own ends up with a more overconfident predictive distribution than the Dropout and Deep Ensembles baselines.

(a) Results for LeNet-5
(b) Results for VGG-11
(c) Results for ResNet-18
Figure 1: Empirical CDF of entropy for out-of-domain data: (a) LeNet-5 on notMNIST, (b) VGG-11 and (c) ResNet-18 on the five classes of CIFAR10 hidden during training. SBN corresponds to a model with all Batch Normalization layers replaced by Stochastic Batch Normalization. The further to the right and the lower the curve, the better.


This research is in part based on the work supported by Samsung Research, Samsung Electronics.


  • Gal & Ghahramani (2015) Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. arXiv:1506.02142, 2015.
  • He et al. (2015) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. CoRR, abs/1512.03385, 2015.
  • Hoffman et al. (2013) Matthew D. Hoffman, David M. Blei, Chong Wang, and John Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14:1303–1347, 2013.
  • Ioffe & Szegedy (2015) Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. CoRR, abs/1502.03167, 2015.
  • Kingma et al. (2015) Diederik P Kingma, Tim Salimans, and Max Welling. Variational dropout and the local reparameterization trick. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett (eds.), Advances in Neural Information Processing Systems 28, pp. 2575–2583. Curran Associates, Inc., 2015.
  • Lakshminarayanan et al. (2017) Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (eds.), Advances in Neural Information Processing Systems 30, pp. 6405–6416. Curran Associates, Inc., 2017.
  • Louizos & Welling (2017) Christos Louizos and Max Welling. Multiplicative normalizing flows for variational bayesian neural networks. In Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, pp. 2218–2227, 2017.
  • MacKay (1992) David J. C. MacKay. A practical bayesian framework for backpropagation networks. Neural Comput., 4(3):448–472, May 1992. ISSN 0899-7667. doi: 10.1162/neco.1992.4.3.448.
  • Molchanov et al. (2017) Dmitry Molchanov, Arsenii Ashukha, and Dmitry Vetrov. Variational dropout sparsifies deep neural networks. arXiv preprint arXiv:1701.05369, 2017.
  • Srivastava et al. (2014) Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. J. Mach. Learn. Res., 15(1):1929–1958, January 2014. ISSN 1532-4435.
  • Welling & Teh (2011) Max Welling and Yee Whye Teh. Bayesian learning via stochastic gradient langevin dynamics. In Lise Getoor and Tobias Scheffer (eds.), ICML, pp. 681–688. Omnipress, 2011.

Appendix A Lower bound on marginal log-likelihood

In Sec. 2 we propose a probabilistic view of Batch Normalization which models the marginal likelihood $p_\theta(y \mid x)$. In this section we show that conventional Batch Normalization actually optimizes a lower bound on the marginal log-likelihood in this probabilistic model. The goal is to train the model parameters $\theta$ given the training dataset $\mathcal{D}$. Following the Maximum Likelihood approach, we need to maximize the following objective $L(\theta)$:
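$$L(\theta) = \sum_{i=1}^{N} \log p_\theta(y_i \mid x_i) = \sum_{i=1}^{N} \log \mathbb{E}_{r(\mu, \sigma^2 \mid x_i)}\, p_\theta(y_i \mid x_i, \mu, \sigma^2) \tag{5}$$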

However, the term $\log p_\theta(y_i \mid x_i)$ is intractable due to the expectation over statistics. We therefore construct a lower bound $\mathcal{L}(\theta)$ on $L(\theta)$ using Jensen's inequality:
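$$L(\theta) \ge \sum_{i=1}^{N} \mathbb{E}_{r(\mu, \sigma^2 \mid x_i)} \log p_\theta(y_i \mid x_i, \mu, \sigma^2) = \mathcal{L}(\theta) \tag{6}$$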

To use gradient-based optimization methods we need to compute the gradient of $\mathcal{L}(\theta)$ w.r.t. the parameters $\theta$. Unfortunately, the distribution over $\mu, \sigma^2$ depends on $\theta$ and, therefore, we cannot propagate the gradient through the expectation. However, we can use the definition of $r(\mu, \sigma^2 \mid x)$ in (3) and reparametrize the expectation in terms of the mini-batch distribution:

Since the distribution over mini-batches does not depend on $\theta$, we can now propagate the gradient through the expectation and use an MC approximation for an unbiased estimate. During training, Batch Normalization draws a mini-batch $B$ of size $M$ and approximates the full gradient in the following way:

Note that Batch Normalization uses the same mini-batch to calculate the statistics as for gradient estimation. Taking an expectation over the mini-batch $B$, we can see that this procedure yields an unbiased estimate of $\nabla_\theta \mathcal{L}(\theta)$:

So Batch Normalization produces an unbiased estimate of the gradient $\nabla_\theta \mathcal{L}(\theta)$ during training and can be seen as approximate inference in the proposed probabilistic model.
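The three steps of this argument (reparametrization, the mini-batch estimator, and its unbiasedness) can be written out as follows. This is a hedged reconstruction from the surrounding prose, not the authors' typeset derivation: the symbols $M$ (mini-batch size), $\hat{g}(B)$ (the gradient estimate) and the $N/M$ scaling are assumptions, as is the premise that every object is equally likely to fall into the mini-batch and that, given $x_i \in B$, the rest of the batch is distributed as $B_{\setminus x_i}$.

```latex
\begin{aligned}
\mathcal{L}(\theta)
  &= \sum_{i=1}^{N} \mathbb{E}_{r(\mu, \sigma^2 \mid x_i)}
     \log p_\theta(y_i \mid x_i, \mu, \sigma^2) \\
  &= \sum_{i=1}^{N} \mathbb{E}_{B_{\setminus x_i}}
     \log p_\theta\big(y_i \mid x_i, \mu(B), \sigma^2(B)\big),
     \qquad B = \{x_i\} \cup B_{\setminus x_i}, \\[4pt]
\hat{g}(B)
  &= \frac{N}{M} \sum_{(x, y) \in B}
     \nabla_\theta \log p_\theta\big(y \mid x, \mu(B), \sigma^2(B)\big), \\[4pt]
\mathbb{E}_{B}\, \hat{g}(B)
  &= \frac{N}{M} \cdot M \cdot \frac{1}{N} \sum_{i=1}^{N}
     \mathbb{E}_{B_{\setminus x_i}}
     \nabla_\theta \log p_\theta\big(y_i \mid x_i, \mu(B), \sigma^2(B)\big)
   = \nabla_\theta \mathcal{L}(\theta).
\end{aligned}
```

The second line follows because $r(\mu, \sigma^2 \mid x_i)$ is an expectation of Dirac deltas over mini-batches, and the gradient in the last line may be moved inside the expectation because the distribution over mini-batches does not depend on $\theta$.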

Appendix B Statistics distribution approximation

For computational and memory efficiency we propose the following approximation for the real distribution over the batch statistics, induced by Batch Normalization:
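$$r(\mu) = \mathcal{N}(\mu \mid m_\mu, s_\mu^2), \qquad r(\sigma^2) = \mathrm{Log}\mathcal{N}(\sigma^2 \mid m_\sigma, s_\sigma^2) \tag{7}$$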

According to our observations, the real distributions are unimodal (Fig. 2). Moreover, the Central Limit Theorem implies that the means converge in distribution to Gaussians; therefore we model the distribution of the means with a fully-factorized Gaussian. While the common choice for the variance is the Gamma distribution, we choose the log-normal distribution, as it allows for more tractable moment matching. As the plots show, the log-normal distribution also fits the data well.

To verify the choice of parametric family, we estimate the empirical marginal distributions over the means and variances for the LeNet-5 architecture on the MNIST dataset. To sample statistics from the real distribution we pass different mini-batches from the training data through the network. We use Kernel Density Estimation to plot the empirical distributions. The results for the convolutional and fully-connected layers of LeNet-5 are shown in Fig. 2 and 3. The approximation (7) fits the real marginal distributions over the statistics very accurately.
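The comparison described above can be sketched for the per-batch means of a single feature: a Gaussian kernel density estimate of the sampled statistics is compared with the density of the moment-matched Gaussian. The bandwidth and function names are assumptions, and the paper does not state which KDE settings were used.

```python
import math


def normal_pdf(x, mean, std):
    """Density of a Gaussian with the given mean and standard deviation."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def gaussian_kde(samples, x, bandwidth):
    """Kernel density estimate at x with Gaussian kernels of width bandwidth."""
    return sum(normal_pdf(x, s, bandwidth) for s in samples) / len(samples)


def fit_gaussian(samples):
    """Moment-matched Gaussian: returns (mean, standard deviation)."""
    mean = sum(samples) / len(samples)
    var = sum((s - mean) ** 2 for s in samples) / len(samples)
    return mean, math.sqrt(var)


def max_density_gap(samples, grid, bandwidth):
    """Largest absolute gap between the KDE and the fitted Gaussian on a grid."""
    mean, std = fit_gaussian(samples)
    return max(abs(gaussian_kde(samples, x, bandwidth) - normal_pdf(x, mean, std))
               for x in grid)
```

For the variances, the same check would be applied to the logarithm of the sampled batch variances, since a log-normal variable is Gaussian on the log scale.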

(a) Distributions for LeNet-5 conv1
(b) Distributions for LeNet-5 conv2
Figure 2: The empirical marginal distribution over statistics (blue) for the convolutional LeNet-5 layers and the proposed approximation (green). The top row corresponds to the means, and the bottom row corresponds to the variances.
Figure 3: The empirical marginal distribution over statistics (blue) for fully-connected LeNet-5 layer and the proposed approximation (green). Top row corresponds to the means, and the bottom row corresponds to the variances.