Uncertainty Estimation via Stochastic Batch Normalization
In this work, we investigate the Batch Normalization technique and propose its probabilistic interpretation. We propose a probabilistic model and show that Batch Normalization maximizes a lower bound of its marginalized log-likelihood. Then, according to the new probabilistic model, we design an algorithm which acts consistently during training and testing. However, inference becomes computationally inefficient. To reduce memory and computational cost, we propose Stochastic Batch Normalization, an efficient approximation of the proper inference procedure. This method provides us with a scalable uncertainty estimation technique. We demonstrate the performance of Stochastic Batch Normalization on popular architectures (including deep convolutional architectures: VGG-like and ResNets) for MNIST and CIFAR-10 datasets.
Deep Neural Networks (DNNs) have achieved state-of-the-art quality on many problems and are successfully integrated in real-life scenarios: semantic segmentation, object detection and scene recognition, to name but a few. Usually the quality of a model is measured in terms of accuracy; however, accurate uncertainty estimation is also crucial for real-life decision-making applications, such as self-driving systems and medical diagnostics. Despite their high accuracy, DNNs are prone to overconfidence even on out-of-domain data.
The Bayesian framework lends itself well to uncertainty estimation (MacKay, 1992), but exact Bayesian inference is intractable for large models such as DNNs. To address this issue, a number of approximate inference techniques have been proposed recently (Welling & Teh, 2011; Hoffman et al., 2013). It has been shown that Dropout, a well-known regularization technique (Srivastava et al., 2014), can be treated as a special case of stochastic variational inference (Kingma et al., 2015; Molchanov et al., 2017). Gal & Ghahramani (2015) also showed that the stochasticity induced by Dropout can provide well-calibrated uncertainty estimation for DNNs. Multiplicative Normalizing Flows (Louizos & Welling, 2017) is another approximation technique that produces good uncertainty estimates. However, such a complex method is hard to scale to very deep convolutional architectures. Moreover, the recently proposed Residual Network (He et al., 2015) with more than a hundred layers does not have any noise-inducing layers such as Dropout, because adding this type of layer leads to a significant accuracy degradation (He et al., 2015). This problem can be addressed by the non-Bayesian Deep Ensembles method (Lakshminarayanan et al., 2017), which provides competitive uncertainty estimation, but it requires storing several separate models and performing forward passes through all of them to make a prediction.
Batch Normalization (Ioffe & Szegedy, 2015) is an essential part of very deep convolutional architectures. In our work, we treat Batch Normalization as a stochastic layer and propose a way to ensemble batch-normalized networks. The straightforward technique, however, ends up with high memory and computational cost. We therefore propose Stochastic Batch Normalization (SBN), an efficient and scalable approximation technique. We show the performance of our method on the out-of-domain uncertainty estimation problem for deep convolutional architectures including VGG-like, ResNet and LeNet-5 on MNIST and CIFAR10 datasets. We also demonstrate that SBN successfully extends the Dropout and Deep Ensembles methods.
We consider a supervised learning problem with a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$. The goal is to train the parameters $\theta$ of the predictive likelihood $p(y \mid x, \theta)$, modelled by a neural network. To solve this problem, stochastic optimization methods with a mini-batch gradient estimator are usually used.
Batch Normalization Batch Normalization attempts to keep the activations of all layers at zero mean and unit variance. To do so, it uses the mean and variance over the mini-batch during training and accumulated statistics during inference:
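Written out in one common notation (the symbols $\mu(B)$, $\sigma^2(B)$, $\hat{\mu}$, $\hat{\sigma}^2$, $\gamma$, $\beta$ and $\epsilon$ are assumed here rather than taken from the original typesetting), the two modes of transformation (1) can be sketched as:

```latex
\begin{aligned}
\mathrm{BN}^{\text{train}}(x_i) &= \gamma \, \frac{x_i - \mu(B)}{\sqrt{\sigma^2(B) + \epsilon}} + \beta, \\
\mathrm{BN}^{\text{test}}(x_i)  &= \gamma \, \frac{x_i - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} + \beta, \qquad (1)
\end{aligned}
```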
where $\gamma, \beta$ are the trainable Batch Normalization parameters (scale and shift) and $\epsilon$ is a small constant needed for numerical stability. Note that during training the mean and variance are computed over a randomly picked batch, $(\mu(B), \sigma^2(B))$, while during testing the exponentially smoothed statistics $(\hat{\mu}, \hat{\sigma}^2)$ are used. We address this inconsistency with the proposed probabilistic model.
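The train/test inconsistency can be made concrete with a minimal per-feature sketch in plain Python. The function names, the momentum value and the toy numbers are illustrative assumptions, not the authors' implementation.

```python
import math

def batch_stats(values):
    """Mean and (biased) variance of one feature over a mini-batch."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var

def normalize(x, mean, var, gamma=1.0, beta=0.0, eps=1e-5):
    """Batch Normalization transform for a single activation."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta

def bn_train(batch, running, momentum=0.1):
    """Training mode: use the statistics of the current mini-batch and
    update the exponentially smoothed (running) statistics in place."""
    mean, var = batch_stats(batch)
    running["mean"] = (1 - momentum) * running["mean"] + momentum * mean
    running["var"] = (1 - momentum) * running["var"] + momentum * var
    return [normalize(x, mean, var) for x in batch]

def bn_test(x, running):
    """Test mode: use the accumulated statistics instead of a batch."""
    return normalize(x, running["mean"], running["var"])

running = {"mean": 0.0, "var": 1.0}
out_a = bn_train([1.0, 2.0, 3.0, 4.0], running)
out_b = bn_train([1.0, 2.0, 3.0, 10.0], running)
# The same input x = 1.0 is mapped to different outputs depending on the
# mini-batch it was drawn with, and to yet another value in test mode.
print(out_a[0], out_b[0], bn_test(1.0, running))
```

The output for a fixed input is therefore a random quantity during training (through the choice of the other batch members) but a deterministic one at test time.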
Batch Normalization: Probabilistic View Note from (1) that the forward pass through a batch-normalized network depends not only on the input $x_i$ but on the entire batch $B$ as well. This dependency can be reinterpreted in terms of the mini-batch statistics $\mu(B), \sigma^2(B)$:
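One way to write this reinterpretation, with $B_{\setminus i}$ and the conditional-likelihood notation assumed for illustration, is:

```latex
p(y_i \mid x_i, B_{\setminus i}, \theta)
  = p\bigl(y_i \mid x_i, \mu(B), \sigma^2(B), \theta\bigr),
  \qquad B = \{x_i\} \cup B_{\setminus i} \qquad (2)
```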
where $B_{\setminus i}$ is the batch without $x_i$. Due to the stochastic choice of mini-batches during training, for a fixed $x_i$ the rest of the batch $B_{\setminus i}$ is a random variable, so the mini-batch statistics can be treated as random variables. The conditional distribution of the statistics given the whole batch is the product of two Dirac delta functions, centered at $\mu(B)$ and $\sigma^2(B)$, since the statistics are deterministic functions of the mini-batch, and the distribution of the mean and variance given $x_i$ alone is the expectation of this product over the mini-batch distribution. During inference we average the predictive distribution over the normalization statistics:
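In assumed notation, with $p(\mu, \sigma^2 \mid x_i, \theta)$ denoting the distribution over statistics induced by the random choice of the rest of the mini-batch, this averaging can be sketched as:

```latex
\begin{aligned}
p(\mu, \sigma^2 \mid x_i, \theta)
  &= \mathbb{E}_{B_{\setminus i}} \bigl[ \delta\bigl(\mu - \mu(B)\bigr) \, \delta\bigl(\sigma^2 - \sigma^2(B)\bigr) \bigr], \\
p(y_i \mid x_i, \theta)
  &= \mathbb{E}_{p(\mu, \sigma^2 \mid x_i, \theta)} \, p(y_i \mid x_i, \mu, \sigma^2, \theta). \qquad (3)
\end{aligned}
```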
Connection to Batch Normalization In Appendix A we show that during training Batch Normalization (1) performs an unbiased one-sample Monte Carlo (MC) estimation of the gradient of a lower bound on the marginalized log-likelihood (3). Thus, this probabilistic model corresponds to Batch Normalization during training. However, in the test phase Batch Normalization uses the exponentially smoothed statistics $(\hat{\mu}, \hat{\sigma}^2)$, which can be seen as a biased approximation of (3):
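Under the assumed notation $p(\mu, \sigma^2 \mid x_i, \theta)$ for the distribution over statistics, this approximation replaces the average over statistics by a single evaluation at the smoothed values, which can be sketched as:

```latex
\mathbb{E}_{p(\mu, \sigma^2 \mid x_i, \theta)} \, p(y_i \mid x_i, \mu, \sigma^2, \theta)
  \;\approx\; p(y_i \mid x_i, \hat{\mu}, \hat{\sigma}^2, \theta)
```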
Straightforward MC averaging can be used for a better, unbiased estimation of (3); however, it is inefficient in practice. Indeed, to draw one sample from the distribution over statistics (3) we need to pass an entire mini-batch through the network. So, to perform MC averaging for a single test object, we need several forward passes with different mini-batches sampled from the training data. To address this drawback we propose Stochastic Batch Normalization.
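The cost of straightforward MC averaging can be illustrated with a toy sketch in plain Python. The one-feature "network" (normalization followed by a logistic output), its weights and all function names are hypothetical stand-ins chosen for illustration.

```python
import math
import random

def batch_stats(values):
    """Mean and (biased) variance of one feature over a mini-batch."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var

def predict_given_stats(x, mean, var, w=2.0, b=0.0, eps=1e-5):
    """Toy 'network': normalize one feature, then a logistic output."""
    z = (x - mean) / math.sqrt(var + eps)
    return 1.0 / (1.0 + math.exp(-(w * z + b)))

def mc_predict(x, train_inputs, num_samples, batch_size, rng):
    """Average the prediction over statistics of random mini-batches.
    Every sample requires processing a whole mini-batch, which is what
    makes this estimator expensive for a real deep network."""
    total = 0.0
    for _ in range(num_samples):
        rest = rng.sample(train_inputs, batch_size - 1)  # batch without x
        mean, var = batch_stats([x] + rest)
        total += predict_given_stats(x, mean, var)
    return total / num_samples

rng = random.Random(0)
train_inputs = [rng.gauss(0.0, 1.0) for _ in range(1000)]
p_mc = mc_predict(0.5, train_inputs, num_samples=200, batch_size=32, rng=rng)
print(p_mc)
```

In this sketch a single prediction touches `num_samples * batch_size` inputs; in a deep network each of those would be a full forward pass.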
Stochastic Batch Normalization To address the memory and computational cost of straightforward MC estimation, we propose to approximate the distribution of the Batch Normalization statistics with a fully-factorized parametric approximation. We parameterize the distributions of the mean and of the variance in the following way:
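Consistent with the families motivated in Appendix B (a Gaussian for the mean and a log-normal for the variance), the parameterization can be sketched as follows. The parameter names $m_\mu, s_\mu, m_\sigma, s_\sigma$ are assumed; placing the log-normal on $\sigma$ instead of $\sigma^2$ gives the same family up to a rescaling of the parameters.

```latex
r(\mu) = \mathcal{N}\bigl(\mu \mid m_\mu, s_\mu^2\bigr), \qquad
r(\sigma^2) = \mathrm{LogNormal}\bigl(\sigma^2 \mid m_\sigma, s_\sigma^2\bigr) \qquad (4)
```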
This approximation works well in practice: in Appendix B we show that it accurately fits the real marginals. Since the approximation no longer depends on the training data, samples for each layer can be drawn without passing an entire batch through the network, which makes prediction efficient.
To adjust the parameters of the approximation, we minimize the KL divergence between the distribution induced by Batch Normalization (3) and our approximation for each object:
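In assumed notation, with $p(\mu, \sigma^2 \mid x_i, \theta)$ for the induced distribution and $r$ for the factorized approximation, the objective can be sketched as:

```latex
\min_{m_\mu,\, s_\mu,\, m_\sigma,\, s_\sigma} \;
\mathrm{KL}\bigl( p(\mu, \sigma^2 \mid x_i, \theta) \,\big\|\, r(\mu)\, r(\sigma^2) \bigr)
```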
Since the approximation belongs to the exponential family, this minimization problem is equivalent to moment matching and does not require gradient computation. In our implementation we simply use exponential smoothing to approximate the sufficient statistics of the mean and variance distributions. This can be done for any pre-trained batch-normalized network.
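A minimal sketch of this procedure for one feature is given below, assuming a Gaussian over the batch mean and a log-normal over the batch variance. For these families, moment matching amounts to tracking the first two moments of the mean and of the log-variance. The function names, the momentum and the synthetic data are illustrative assumptions.

```python
import math
import random

def update_moments(state, mu, var, momentum=0.05):
    """Exponentially smooth the first and second moments of the batch
    mean and of the logarithm of the batch variance."""
    log_var = math.log(var)
    for key, value in (("mu", mu), ("mu_sq", mu * mu),
                       ("log_var", log_var), ("log_var_sq", log_var ** 2)):
        state[key] = (1 - momentum) * state[key] + momentum * value

def fitted_parameters(state):
    """Moment matching: Gaussian for the mean, log-normal for the variance."""
    s_mu = math.sqrt(max(state["mu_sq"] - state["mu"] ** 2, 0.0))
    s_sigma = math.sqrt(max(state["log_var_sq"] - state["log_var"] ** 2, 0.0))
    return state["mu"], s_mu, state["log_var"], s_sigma

def sbn_forward(x, params, rng, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize x with statistics sampled from the fitted approximation,
    so no mini-batch is needed at prediction time."""
    m_mu, s_mu, m_sigma, s_sigma = params
    mu = rng.gauss(m_mu, s_mu)
    var = math.exp(rng.gauss(m_sigma, s_sigma))  # log-normal sample
    return gamma * (x - mu) / math.sqrt(var + eps) + beta

rng = random.Random(0)
data = [rng.gauss(3.0, 2.0) for _ in range(5000)]  # synthetic activations
state = {"mu": 0.0, "mu_sq": 0.0, "log_var": 0.0, "log_var_sq": 0.0}
for _ in range(500):  # many updates, so the zero initialization is forgotten
    batch = rng.sample(data, 32)
    mean = sum(batch) / len(batch)
    var = sum((v - mean) ** 2 for v in batch) / len(batch)
    update_moments(state, mean, var)
params = fitted_parameters(state)
samples = [sbn_forward(3.0, params, rng) for _ in range(5)]
print(params, samples)
```

Averaging the network output over several such samples then plays the role of the MC average over mini-batches, at the cost of one forward pass of a single object per sample.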
|No SBN||SBN||No SBN||SBN|
|LeNet-5 MNIST||SBN||—||0.53 0.05||—|
|Deep Ensembles||0.43 0.00|
|Dropout||0.49 0.00||0.015 0.000|
|Deep Ensembles||5.18 0.00||5.23 0.00||0.154 0.002|
|Dropout||5.32 0.00||0.149 0.001|
|Deep Ensembles||3.37 0.00||0.110 0.004|
We evaluate uncertainties on MNIST and CIFAR10 datasets using convolutional architectures. In order to apply Stochastic Batch Normalization to existing architectures we only need to update the parameters of our approximation (4), which does not affect the training process at all. We show that SBN improves both the Dropout and Deep Ensembles techniques in terms of out-of-domain uncertainty and test Negative Log-Likelihood (NLL), and maintains the same level of accuracy.
Experimental Setup We compare our method with Dropout and Deep Ensembles. Since He et al. (2015) showed that ResNet does not perform well with any Dropout layer and suffers from instability, we did not consider this method for the ResNet architecture. For Deep Ensembles we trained 6 models for all architectures and did not use the adversarial training suggested by Lakshminarayanan et al. (2017), since this technique results in lower accuracy.
Uncertainty estimation on notMNIST For this experiment we trained a LeNet-5 model on MNIST, evaluated the entropy of the predictive distribution on notMNIST, which is out-of-domain data for MNIST, and plotted the empirical CDF in Fig. 0(a). We also report the test set accuracy and NLL scores; the results are given in Tab. 1.
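The evaluation metric can be sketched in a few lines of plain Python: the entropy of each predictive distribution is computed, and the empirical CDF of these entropies is formed. The three predictive distributions below are hypothetical placeholders, not results from the experiments.

```python
import math

def predictive_entropy(probs):
    """Entropy (in nats) of one categorical predictive distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def empirical_cdf(values):
    """Return the sorted values and the fraction of values <= each one."""
    ordered = sorted(values)
    n = len(ordered)
    return ordered, [(i + 1) / n for i in range(n)]

# Hypothetical predictive distributions on three out-of-domain inputs.
predictions = [
    [0.98, 0.01, 0.01],      # overconfident
    [0.50, 0.30, 0.20],
    [1 / 3, 1 / 3, 1 / 3],   # maximally uncertain
]
entropies = [predictive_entropy(p) for p in predictions]
xs, cdf = empirical_cdf(entropies)
print(list(zip(xs, cdf)))
```

On out-of-domain data, a CDF curve that lies further to the right (more mass at high entropy) corresponds to less overconfident predictions.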
Uncertainty estimation on CIFAR10 To show that our method scales well to deep convolutional architectures, we perform experiments on VGG-like and ResNet architectures. We split the CIFAR10 dataset into two datasets (CIFAR5) and perform an experiment similar to the one for MNIST and notMNIST. We trained networks on 5 randomly chosen classes and evaluated predictive uncertainty on the remaining ones.
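The class split can be sketched as follows. The helper name, the seed and the stand-in label list are assumptions made for illustration; the actual choice of classes used in the experiments is not specified here.

```python
import random

def split_by_classes(labels, num_classes=10, num_in_domain=5, seed=0):
    """Pick `num_in_domain` classes at random and return them together
    with the indices of in-domain and out-of-domain examples."""
    rng = random.Random(seed)
    classes = list(range(num_classes))
    rng.shuffle(classes)
    in_domain = set(classes[:num_in_domain])
    in_idx = [i for i, y in enumerate(labels) if y in in_domain]
    out_idx = [i for i, y in enumerate(labels) if y not in in_domain]
    return sorted(in_domain), in_idx, out_idx

labels = [i % 10 for i in range(100)]  # stand-in for the CIFAR10 labels
in_classes, in_idx, out_idx = split_by_classes(labels)
print(in_classes, len(in_idx), len(out_idx))
```

The network is trained and tested on the in-domain indices only, while the out-of-domain indices are used solely to evaluate the predictive entropy.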
We observed that Stochastic Batch Normalization improves both Dropout and Deep Ensembles in terms of out-of-domain uncertainties and NLL score on test data (from the same domain) at the same level of accuracy. However, SBN on its own ends up with a more overconfident predictive distribution than the Dropout and Deep Ensembles baselines.
This research is in part based on the work supported by Samsung Research, Samsung Electronics.
- Gal & Ghahramani (2015) Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. arXiv:1506.02142, 2015.
- He et al. (2015) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. CoRR, abs/1512.03385, 2015.
- Hoffman et al. (2013) Matthew D. Hoffman, David M. Blei, Chong Wang, and John Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14:1303–1347, 2013.
- Ioffe & Szegedy (2015) Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. CoRR, abs/1502.03167, 2015.
- Kingma et al. (2015) Diederik P Kingma, Tim Salimans, and Max Welling. Variational dropout and the local reparameterization trick. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett (eds.), Advances in Neural Information Processing Systems 28, pp. 2575–2583. Curran Associates, Inc., 2015.
- Lakshminarayanan et al. (2017) Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (eds.), Advances in Neural Information Processing Systems 30, pp. 6405–6416. Curran Associates, Inc., 2017.
- Louizos & Welling (2017) Christos Louizos and Max Welling. Multiplicative normalizing flows for variational bayesian neural networks. In Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, pp. 2218–2227, 2017.
- MacKay (1992) David J. C. MacKay. A practical bayesian framework for backpropagation networks. Neural Comput., 4(3):448–472, May 1992. ISSN 0899-7667. doi: 10.1162/neco.1992.4.3.448.
- Molchanov et al. (2017) Dmitry Molchanov, Arsenii Ashukha, and Dmitry Vetrov. Variational dropout sparsifies deep neural networks. arXiv preprint arXiv:1701.05369, 2017.
- Srivastava et al. (2014) Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. J. Mach. Learn. Res., 15(1):1929–1958, January 2014. ISSN 1532-4435.
- Welling & Teh (2011) Max Welling and Yee Whye Teh. Bayesian learning via stochastic gradient langevin dynamics. In Lise Getoor and Tobias Scheffer (eds.), ICML, pp. 681–688. Omnipress, 2011.
Appendix A Lower bound on marginal log-likelihood
In Sec. 2 we proposed a probabilistic view of Batch Normalization, which models the marginal likelihood $p(y \mid x, \theta)$ (3). In this section we show that conventional Batch Normalization actually optimizes a lower bound on the marginal log-likelihood in this probabilistic model. The goal is to train the model parameters $\theta$ given the training dataset $\mathcal{D}$. Using the Maximum Likelihood approach, we need to maximize the following objective:
However, the marginal log-likelihood term is intractable due to the expectation over the statistics. We therefore construct a lower bound on the objective using Jensen's inequality:
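In assumed notation, with $\mathcal{L}(\theta)$ for the maximum-likelihood objective, $\widetilde{\mathcal{L}}(\theta)$ for its lower bound and $p(\mu, \sigma^2 \mid x_i, \theta)$ for the distribution over statistics, the bound follows from the concavity of the logarithm and can be sketched as:

```latex
\begin{aligned}
\mathcal{L}(\theta)
  &= \sum_{i=1}^{N} \log \mathbb{E}_{p(\mu, \sigma^2 \mid x_i, \theta)} \, p(y_i \mid x_i, \mu, \sigma^2, \theta) \\
  &\ge \sum_{i=1}^{N} \mathbb{E}_{p(\mu, \sigma^2 \mid x_i, \theta)} \log p(y_i \mid x_i, \mu, \sigma^2, \theta)
   = \widetilde{\mathcal{L}}(\theta)
\end{aligned}
```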
To use gradient-based optimization methods, we need to compute the gradient of the lower bound w.r.t. the parameters $\theta$. Unfortunately, the distribution over the statistics depends on $\theta$, and therefore we cannot propagate the gradient through the expectation directly. However, we can use the definition of the distribution over statistics in (3) and reparametrize the expectation in terms of the mini-batch distribution:
Since the distribution over mini-batches does not depend on $\theta$, we can now propagate the gradient through the expectation and use an MC approximation for unbiased estimation. During training, Batch Normalization draws a mini-batch $B$ of size $M$ and approximates the full gradient in the following way:
Note that Batch Normalization uses the same mini-batch to calculate statistics as for gradient estimation. Taking the expectation over the mini-batch $B$, we can see that this procedure yields an unbiased estimate of the gradient of the lower bound:
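Under assumed notation (a mini-batch $B$ of size $M$ drawn uniformly without replacement from the $N$ training objects, $B_{\setminus i}$ for the batch without $x_i$, and $\widetilde{\mathcal{L}}$ for the lower bound), the steps of this argument can be summarized in the following sketch:

```latex
\begin{aligned}
\nabla_\theta \widetilde{\mathcal{L}}(\theta)
  &= \sum_{i=1}^{N} \nabla_\theta \, \mathbb{E}_{B_{\setminus i}}
     \log p\bigl(y_i \mid x_i, \mu(B), \sigma^2(B), \theta\bigr), \\
g(B)
  &= \frac{N}{M} \sum_{j \in B} \nabla_\theta
     \log p\bigl(y_j \mid x_j, \mu(B), \sigma^2(B), \theta\bigr), \\
\mathbb{E}_{B} \, g(B)
  &= \frac{N}{M} \sum_{i=1}^{N} \Pr(i \in B) \,
     \mathbb{E}_{B_{\setminus i}} \nabla_\theta \log p\bigl(y_i \mid x_i, \mu(B), \sigma^2(B), \theta\bigr)
   = \nabla_\theta \widetilde{\mathcal{L}}(\theta),
\end{aligned}
```

where the last equality uses $\Pr(i \in B) = M / N$.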
So Batch Normalization produces an unbiased estimate of the gradient of the lower bound during training and can be seen as approximate inference in the proposed probabilistic model.
Appendix B Statistics distribution approximation
For computational and memory efficiency we propose the following approximation for the real distribution over the batch statistics, induced by Batch Normalization:
According to our observations, the real distributions are unimodal (Fig. 2). The Central Limit Theorem also implies that the means converge in distribution to Gaussians, and therefore we model this distribution using a fully-factorized Gaussian. While the common choice for the variance is the Gamma distribution, we choose the log-normal distribution, as it allows for more tractable moment matching. As we show in the plots, the log-normal distribution also fits the data well.
To verify the choice of parametric family, we estimate the empirical marginal distributions over the mean and variance statistics for the LeNet-5 architecture on the MNIST dataset. To sample statistics from the real distribution we pass different mini-batches from the training data through the network. We use Kernel Density Estimation to plot the empirical distribution. The results for convolutional and fully-connected layers of LeNet-5 can be seen in Figs. 2 and 3. The approximation (7) fits the real marginal distributions over the statistics very accurately.
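The kind of check described above can be sketched on synthetic data in plain Python: batch means are sampled, a Gaussian kernel density estimate is built, and it is compared with the moment-matched Gaussian. The data, the batch size and the normal-reference bandwidth rule are assumptions for illustration; the sketch does not reproduce the LeNet-5 experiment.

```python
import math
import random

def gaussian_kde(samples, bandwidth):
    """Return a function evaluating a Gaussian kernel density estimate."""
    norm = 1.0 / (len(samples) * bandwidth * math.sqrt(2.0 * math.pi))
    def density(x):
        return norm * sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2)
                          for s in samples)
    return density

def gaussian_pdf(x, mean, std):
    """Density of a Gaussian with the given mean and standard deviation."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

rng = random.Random(0)
data = [rng.gauss(0.0, 1.0) for _ in range(2000)]  # synthetic activations
batch_means = []
for _ in range(300):
    batch = rng.sample(data, 32)
    batch_means.append(sum(batch) / len(batch))

# Moment-matched Gaussian for the batch means.
mean = sum(batch_means) / len(batch_means)
std = math.sqrt(sum((m - mean) ** 2 for m in batch_means) / len(batch_means))
# Normal-reference rule of thumb for the bandwidth (an assumed default).
bandwidth = 1.06 * std * len(batch_means) ** (-0.2)
kde = gaussian_kde(batch_means, bandwidth)

# Largest pointwise gap between the two densities within two standard deviations.
grid = [mean + k * std / 4 for k in range(-8, 9)]
max_gap = max(abs(kde(x) - gaussian_pdf(x, mean, std)) for x in grid)
print(max_gap, gaussian_pdf(mean, mean, std))
```

A small gap relative to the peak density indicates that the parametric family is a reasonable fit for the sampled statistics.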