Hierarchical Gaussian Process Priors for Bayesian Neural Network Weights


Probabilistic neural networks are typically modeled with independent weight priors, which do not capture weight correlations in the prior and do not provide a parsimonious interface to express properties in function space. A desirable class of priors would represent weights compactly, capture correlations between weights, facilitate calibrated reasoning about uncertainty, and allow inclusion of prior knowledge about the function space such as periodicity or dependence on contexts such as inputs. To this end, this paper introduces two innovations: (i) a Gaussian process-based hierarchical model for network weights based on unit embeddings that can flexibly encode correlated weight structures, and (ii) input-dependent versions of these weight priors that can provide convenient ways to regularize the function space through the use of kernels defined on contextual inputs. We show these models provide desirable test-time uncertainty estimates on out-of-distribution data, demonstrate cases of modeling inductive biases for neural networks with kernels which help both interpolation and extrapolation from training data, and demonstrate competitive predictive performance on an active learning benchmark.

1 Introduction

Bayesian neural networks (BNNs) (see e.g. mackay:92; neal1992bayesian; ghahramani2016history) are one of the research frontiers on combining Bayesian inference and deep learning, potentially offering flexible modelling power with calibrated predictive performance. In essence, applying probabilistic inference to neural networks allows all plausible network parameters, not just the most likely, to be used for predictions. Despite the strong interest in the community for the exploration of BNNs, there remain unanswered questions: (i) how can we model neural network functions to encourage behaviors such as interpolation between signals and extrapolation from data in meaningful ways, for instance by encoding prior knowledge, or how to specify priors which facilitate uncertainty quantification, and (ii) many scalable approximate inference methods are not rich enough to capture complicated posterior correlations in large networks, resulting in undesirable predictive performance at test time.

This paper attempts to tackle some of the aforementioned limitations. We propose inherently correlated weight priors by utilizing unit-level latent variables to obtain a compact parameterization of neural network weights and combine them with ideas from the Gaussian process (GP) literature to induce a hierarchical GP prior over weights. This prior flexibly models the correlations between weights in a layer and across layers. We explore the use of product kernels to implement input-dependence as a variation of the proposed prior, yielding models that have per-datapoint priors which facilitate inclusion of prior knowledge through kernels while maintaining their weight structure. A structured variational inference approach is employed that side-steps the need to do inference in the weight space whilst retaining weight correlations in the approximate posterior. The proposed priors and approximate inference scheme are demonstrated to exhibit beneficial properties for tasks such as generalization, uncertainty quantification, and active learning.

The paper is organized as follows: in Section 2 we review hierarchical modeling for BNNs based on unit-variables. In Section 3 we introduce the global and local weight models and their applications to neural networks. Efficient inference algorithms for both models are presented in Section 4, followed by a suite of experiments to validate their performance in Section 5. We review related work in Section 6.

2 Meta-representing weights and networks

Our work builds on and expands the class of hierarchical neural network models based on the concept of latent variables associated with units in a network as proposed in (karaletsos2018probabilistic). In that model, each unit (visible or hidden) of the -th layer of the network has a corresponding latent hierarchical variable , of dimensions , where denotes the index of the unit in a layer. Note that these latent variables do not describe the activation of units, but rather constitute latent features associated with a unit.

The design of these latent variables is judiciously chosen to construct the weights in the network as follows: a weight in the -th layer, is generated by using the concatenation of latent variable ’s of the -th input unit and the -th output unit as inputs of a mapping function .

We can summarize this relationship by introducing a set of weight encodings , one for each individual weight in the network , which can be deterministically constructed from the collection of unit latent variable samples z by concatenating them correctly according to network architecture. The probabilistic description of the relationship between the weight codes (summarizing the structured latent variables) and the weights is:

where denotes a visible or hidden layer and is the number of units in that layer, is the total number of layers in the network, and denotes all the weights in this network.
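The construction of weight codes from unit latent variables can be sketched as follows. This is a minimal illustration rather than the reference implementation: the function name `build_weight_codes`, the latent dimensionality of 2, and the omission of bias terms are assumptions made for the example; the layer sizes are borrowed from the caption of Fig. 2.

```python
import numpy as np

def build_weight_codes(z_layers):
    """Return one array of weight codes per weight layer.

    z_layers[l] has shape (H_l, D_z): one latent vector per unit in layer l.
    The code of the weight joining input unit i to output unit j is the
    concatenation [z_in[i], z_out[j]], so every code has length 2 * D_z.
    Row i * H_out + j of the returned array belongs to the weight (i, j).
    """
    codes = []
    for z_in, z_out in zip(z_layers[:-1], z_layers[1:]):
        h_in, h_out = z_in.shape[0], z_out.shape[0]
        left = np.repeat(z_in, h_out, axis=0)   # each input unit repeated h_out times
        right = np.tile(z_out, (h_in, 1))       # output units cycled h_in times
        codes.append(np.concatenate([left, right], axis=1))
    return codes

rng = np.random.default_rng(0)
layer_sizes = [1, 20, 10, 1]   # architecture from the Fig. 2 caption
latent_dim = 2                 # assumed latent dimensionality
z_layers = [rng.standard_normal((h, latent_dim)) for h in layer_sizes]
codes = build_weight_codes(z_layers)
```

Because the codes are a deterministic function of the unit variables, resampling a single unit's latent vector changes the codes of every weight attached to that unit at once.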

In karaletsos2018probabilistic, a small parametric neural network regression model (conceptually a structured hyper-network) is chosen to map the latent variables to the weights, using either a Gaussian noise model or an implicit noise model: , where is a random variate. We will call this network a meta mapping. Note that given the collection of sampled unit variables and the derived codes , the weights (or their mean and variance) can be obtained efficiently in parallel. A prior over latent variables completes the model specification, . The joint density of the resulting hierarchical BNN is then specified as follows,

with denoting the number of observation tuples .

Variational inference was employed in prior work to infer (and implicitly), and to obtain a point estimate of , as a by-product of optimising the variational lower bound. Critically, in this representation weights are only implicitly parametrized through the use of these latent variables, which transforms inference on weights into inference of the much smaller collection of latent unit variables.

The central motivations for our adoption of this parameterization are two-fold. First, the number of visible and hidden units in a neural network is typically much smaller than the number of weights. For example, for the -th weight layer, there are weights compared to associated latent variables. This encourages the development of models and inference schemes that can work in the lower-dimensional hierarchical latent space directly without the need to model individual weights privately. This structured representation per weight allows powerful hierarchical models to be used without requiring high dimensional parametrizations (i.e. hypernetworks for the entire weight tensor). Specifically, a GP-LVM prior lawrence2004gaussian over all network weights appears infeasible without such an encoding of structure. Second, the compact latent space representation facilitates attempts at building fine-grained control into weight priors, such as structured prior knowledge as we will see in the following sections.
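To make the first motivation concrete, the following sketch counts weights against latent-variable entries for a fully connected network. The helper name, the latent dimensionality of 2, and the exclusion of biases are assumptions for the example; the 784-100-10 layout mirrors the MNIST classifier of Section 5.3.

```python
def count_weights_and_latents(layer_sizes, latent_dim):
    """Count weights (biases excluded) and unit latent-variable entries."""
    # A weight layer between H_in and H_out units holds H_in * H_out weights.
    n_weights = sum(a * b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))
    # Every unit, visible or hidden, carries one latent vector of size latent_dim.
    n_latents = sum(layer_sizes) * latent_dim
    return n_weights, n_latents

n_weights, n_latents = count_weights_and_latents([784, 100, 10], latent_dim=2)
print(n_weights, n_latents)  # 79400 1788
```

The weight count grows with the product of adjacent layer widths, while the latent count grows only with their sum.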

Figure 1: Graphical depiction of BNNs with hierarchical GP-MetaPriors and ones with input-dependent variables.

3 Hierarchical GP-Priors For BNN Weights

Notice that in Section 2, the meta mapping from the hierarchical latent variables to the weights is a parametric non-linear function, specified by a neural network. We replace the parametric neural network by a probabilistic functional mapping and place a nonparametric Gaussian process prior over this function. That is,

where we have assumed a zero-mean GP, is a covariance function and is a small set of hyper-parameters, and a homoscedastic Gaussian noise model with variance .

The effect is that the latent function introduces correlations for the individual weight predictions,

Notably, while the number of latent variables and weights can be large, the input dimension to the GP mapping is only , where is the dimensionality of each latent variable . The GP mapping effectively performs one-dimensional regression from latent variables to individual weights while capturing their correlations. We will refer to this mapping as a GP-MetaPrior (MetaGP). We define the following kernel at the example of two weights in the network,

In this section and what follows, we will use the popular exponentiated quadratic (RBF) kernel with ARD lengthscales, , where are the lengthscales and is the kernel variance.
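A minimal sketch of the exponentiated quadratic kernel with ARD lengthscales, evaluated on weight codes, is given here for reference. The function name and the numerical values are illustrative assumptions; only the functional form is standard.

```python
import numpy as np

def rbf_ard_kernel(a, b, lengthscales, variance=1.0):
    """Exponentiated quadratic kernel with one lengthscale per input dimension."""
    a = np.asarray(a, dtype=float) / lengthscales
    b = np.asarray(b, dtype=float) / lengthscales
    # Squared Euclidean distances between all rows of a and b after rescaling.
    sq_dist = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2.0 * a @ b.T
    return variance * np.exp(-0.5 * np.maximum(sq_dist, 0.0))

rng = np.random.default_rng(0)
codes = rng.standard_normal((6, 4))            # six hypothetical weight codes, 2 * D_z = 4
lengthscales = np.array([0.5, 0.5, 2.0, 2.0])  # assumed values, one per code dimension
K = rbf_ard_kernel(codes, codes, lengthscales, variance=1.5)
```

A large lengthscale on a given code dimension makes the kernel nearly insensitive to that dimension, which is how ARD can switch off latent features that do not help explain the weights.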

BNNs with GP-MetaPriors are then specified by the following joint density over all variables:

We show prior samples from this model in Fig. 2 by the following procedure: for a sample we instantiate the covariance matrix by constructing weight codes and applying the kernel function . We draw weights from the GP by sampling the Normal distribution , where , with denoting the number of all parameters in the network. We then generate BNN function samples given the sampled weights with homoscedastic noise on the outputs. We highlight two properties of this model: First, as a hierarchical model, given a sample of the latent variables z, the model instantiates a prior over weights from which we can further sample functions and thus encodes two levels of uncertainty over functions. Second, we demonstrate that changing the length-scale parameter of the mapping-kernel , again even given fixed samples z, leads to vastly differentiated function samples, showing the compact degree of control the mapping parameters have over the function space being modeled. In this case, the length-scale appears to control the variance over the function space, which matches an intuitive interpretation over the kernel parameter.
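The sampling procedure described above can be sketched in a few lines. This is an illustration under stated assumptions, not the code used for Fig. 2: the latent variables are drawn from a standard Normal, the kernel is an isotropic RBF, biases and output noise are omitted, and a small jitter is added before the Cholesky factorization for numerical stability.

```python
import numpy as np

def rbf(a, b, lengthscale, variance=1.0):
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return variance * np.exp(-0.5 * sq / lengthscale**2)

def weight_codes(z_layers):
    """Stack the codes [z_in[i], z_out[j]] of all weights, layer by layer."""
    blocks = []
    for z_in, z_out in zip(z_layers[:-1], z_layers[1:]):
        left = np.repeat(z_in, len(z_out), axis=0)
        right = np.tile(z_out, (len(z_in), 1))
        blocks.append(np.concatenate([left, right], axis=1))
    return np.concatenate(blocks, axis=0)

def sample_prior_functions(x, z_layers, lengthscale, n_functions, rng, jitter=1e-6):
    """Draw weights from N(0, K) given fixed z and evaluate a ReLU network at x."""
    sizes = [len(z) for z in z_layers]
    codes = weight_codes(z_layers)
    K = rbf(codes, codes, lengthscale) + jitter * np.eye(len(codes))
    chol = np.linalg.cholesky(K)
    outputs = []
    for _ in range(n_functions):
        w = chol @ rng.standard_normal(len(codes))   # one joint draw of all weights
        h, start = x, 0
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            W = w[start:start + n_in * n_out].reshape(n_in, n_out)
            start += n_in * n_out
            h = h @ W
            if i < len(sizes) - 2:                   # ReLU on hidden layers only
                h = np.maximum(h, 0.0)
        outputs.append(h[:, 0])
    return np.stack(outputs)

rng = np.random.default_rng(0)
layer_sizes = [1, 20, 10, 1]
z_layers = [rng.standard_normal((h, 2)) for h in layer_sizes]  # one sample of z, then fixed
x = np.linspace(-3.0, 3.0, 50)[:, None]
f_small = sample_prior_functions(x, z_layers, lengthscale=0.5, n_functions=40, rng=rng)
f_large = sample_prior_functions(x, z_layers, lengthscale=5.0, n_functions=40, rng=rng)
```

Both calls reuse the same `z_layers`, so any difference between `f_small` and `f_large` is attributable to the kernel lengthscale alone, which is the comparison the left and middle panels of Fig. 2 are meant to convey.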

An important task is marginalization of the latent quantities given data to perform posterior inference, which we discuss in Section 4.1.

Figure 2: MetaGP Prior Samples: We show function samples generated from a [1,20,10,1] unit BNN with ReLUs with meta-GP prior. We draw one sample per unit to instantiate the weight prior and subsequently draw 40 function samples (individually colored) from the BNN by drawing from the conditional prior and regressing . (left) We show samples drawn when the global RBF kernel for the GP has a length scale set to a small value. (middle) we keep the same samples z and only change the length scale to be larger and visualize the functions induced by the BNN (right) We visualize the weight kernel given the latent variables . Other samples of z would induce different weight covariance matrices. Overall this figure shows that even given z, the proposed prior models a wide range of functions which have controllable properties based on the parameters of the kernel.

3.1 Input-kernels for modulating function priors

While the hierarchical latent variables and meta mappings introduce non-trivial coupling between the weights a priori, they are inherently global. That is, a function drawn from the model, represented by a set of weights, does not take into account the inputs at which the function will be evaluated. In this section, we will describe modifications to our weight prior which allow conditional weight models on inputs.

To this end, we introduce the input variable into the weight codes , which we utilize to yield input-conditional weight models through the use of product kernels. Concretely, we introduce a new input kernel which multiplied with the global weight kernel gives the kernel for the meta mapping,

where is the kernel defined over latent-variable weight codes from Section 3, is an auxiliary kernel modeling input-dependence on , and is shorthand for . This factorization over kernels represents an assumption of separable influence on functions by latent variables and inputs. The weight priors are now also local to each data point, in a similar vein to how functions are drawn from a GP, while still instantiating an explicit, weight-based model.
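The product structure can be illustrated with a small sketch in which an RBF kernel on weight codes is multiplied by a periodic kernel on scalar inputs. The function names, the unit lengthscales, and the period of 0.3 are assumptions for the example; the periodic kernel is written in its common form exp(-2 sin^2(pi |x - x'| / p) / l^2).

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale**2)

def periodic(x, y, period, lengthscale=1.0):
    d = np.abs(x[:, None] - y[None, :])
    return np.exp(-2.0 * np.sin(np.pi * d / period) ** 2 / lengthscale**2)

def local_kernel(codes_a, x_a, codes_b, x_b, period):
    """Kernel between (code, input) pairs: code kernel times input kernel."""
    return rbf(codes_a, codes_b) * periodic(x_a, x_b, period)

rng = np.random.default_rng(1)
codes = rng.standard_normal((5, 4))   # five hypothetical weight codes
x = np.full(5, 0.2)                   # all five weights evaluated at the same input
K_same = local_kernel(codes, x, codes, x, period=0.3)
K_shift = local_kernel(codes, x, codes, x + 0.3, period=0.3)   # one full period away
K_half = local_kernel(codes, x, codes, x + 0.15, period=0.3)   # half a period away
```

At identical inputs the auxiliary kernel equals one and the global weight kernel is recovered; one full period away the weights are as strongly coupled as at the same input, whereas half a period away the coupling is damped.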

We demonstrate the effects of utilizing the auxiliary kernel in this factorized fashion by visualizing prior function samples from a BNN with this local prior when changing kernel parameters in Fig. 3, exemplifying the proposed model’s ability to encode controlled periodic structure into BNN weight priors before seeing any data. As performed in Section 3, we sample from the GP to instantiate weights, but in the case of the local model we instantiate the covariance matrix . We will discuss the handling of this conceptually large object in Section 4.2.

To scale this to large inputs, we learn transformations of inputs for the conditional weight model , for a learned mapping and a nonlinearity and generalize weight codes to , with describing their collection. In detail, each auxiliary input is obtained via a (potentially nonlinear) transformation applied to an input: , where , and and are the dimensionality of and , respectively, and is an arbitrary transformation. We may also layer these transformations in general. We typically set so this transformation could be thought of as a dimensionality reduction operation. For low dimensional inputs, we set .

Including these transformations yields the weight model , that is, the input dimension of the meta mapping is now . Additionally, we also place a prior over the linear transformation: . We will refer to this mapping as a Local GP-MetaPrior (MetaGP-local).
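As a rough sketch of how auxiliary inputs enter the weight codes, the snippet here projects a high-dimensional input to a low-dimensional auxiliary input and appends it to every weight code. The names `auxiliary_inputs` and `local_codes`, the tanh nonlinearity, the scaled Gaussian draw for the projection matrix, and all dimensions are assumptions chosen for illustration.

```python
import numpy as np

def auxiliary_inputs(x, V, nonlinearity=np.tanh):
    """Map inputs of dimension D_x to auxiliary inputs of dimension D_a."""
    return nonlinearity(x @ V.T)

def local_codes(weight_codes, a_n):
    """Append one datapoint's auxiliary input to every weight code."""
    tiled = np.tile(a_n, (weight_codes.shape[0], 1))
    return np.concatenate([weight_codes, tiled], axis=1)

rng = np.random.default_rng(2)
D_x, D_a, D_z = 784, 2, 2
V = rng.standard_normal((D_a, D_x)) / np.sqrt(D_x)  # one draw from an assumed Gaussian prior
x = rng.random((3, D_x))                            # three hypothetical inputs
a = auxiliary_inputs(x, V)
weight_codes = rng.standard_normal((10, 2 * D_z))   # ten hypothetical weight codes
codes_n = local_codes(weight_codes, a[0])           # codes for the first datapoint
```

With these choices each generalized code has 2 * D_z + D_a = 6 entries, so the input dimension of the meta mapping grows by the size of the auxiliary input only.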

What effects should we expect from such a modulation? Consider the use of an exponentiated quadratic kernel: we would expect data which lies far away from training data to receive small kernel values from . This, in turn, would modulate the entire kernel for that data point to small values, leading to a weight model that reverts increasingly to the prior. We would expect such a model to help with modeling uncertainty by resetting weights to uninformative distributions away from training data. One may also want to use this mechanism to express inductive biases about the function space, such as adding structure to the weight prior that can be captured with a kernel. This is an appealing avenue, as multiple useful kernels have been found in the GP literature that allow modelers to describe relationships between data, but have previously not been accessible to neural network modelers. We consider this a novel form of functional regularization through the weight prior, which can imbue the entire network with structure that will constrain its function space.

Figure 3: Local MetaGP Samples: We remind the reader that the input-dependent weight prior has a factorized kernel structure , and we wish to demonstrate the effect of each kernel separately in terms of its effects on the induced function prior for the neural network. We are given the same samples as in Fig. 2 and also keep the two kernel parameter choices for , while varying only the period parameter for an auxiliary periodic kernel. Left: We show function samples using a small period of 0.1 and a period of 0.3 in combination with the kernel with length-scale 0.8. We can see, that while the functions are still relatively flat, the auxiliary kernel induces weight priors which lead to periodic function samples consistent with the auxiliary kernel setting. Right: Similarly, when performing the same protocol for the with the larger lengthscale, we again observe periodic functions consistent with the set period (although we only apply the periodic kernel for the weight priors for the BNN), but see that the functions sampled have more variance, consistent with the larger length-scale of the weight-kernel . Note that while the functions exhibit periodic structure, they have non-periodic global structure as well, as they also draw information from and the priors are merely modulated by the auxiliary kernel. We thus see that our prior structure successfully induces function priors which naturally inherit properties we can express as kernel functions, but keep rich expressivity as weight based models.

BNNs with Local GP-MetaPriors specify neural networks with individual weight priors per datapoint (also see Graphical Model in Fig. 1):

Inference and learning are modified accordingly as explained in Section 4.2.

4 Inference and learning using stochastic structured variational inference

Performing inference is challenging due to the non-linearity of the neural network and the need to infer an entire latent function . In Section 4.1 we address these problems for MetaGP, deriving a structured variational inference scheme that makes use of innovations from inducing point GP approximation literature (titsias2009variational; hensman2013gaussian; quinonero2005unifying; matthews2016sparse; bui2017unifying) and previous work on inferring meta-representations (karaletsos2018probabilistic). In Section 4.2 we will highlight the modifications necessary to make this inference strategy work for MetaGP-local.

4.1 Inference for the global model

A common strategy for variational inference in GPs is the utilization of inducing points, which entails the construction of learned inputs to the function and corresponding function values which jointly take the place of representative data points. The inducing inputs and outputs, , will be used to parameterize the approximation.

We first partition the space of inputs (or weight codes) to the function into a finite set of variables called inducing inputs where and the remaining inputs, . The function is partitioned identically, , where . We can then rewrite the GP prior as follows, . In particular, a variational approximation is judiciously chosen to mirror the form of the joint density:


where the variational distribution over is made to explicitly depend on remaining variables through the conditional prior, and is chosen to be a diagonal (mean-field) Gaussian density, , and is chosen to be a correlated multivariate Gaussian, . This approximation allows convenient cancellations yielding a tractable variational lower bound as follows,

where the last expectation has been approximated by simple Monte Carlo with the reparameterization trick, i.e.  (salimans2013fixed; kingma2013auto; titsias2014doubly). We will next discuss how to approximate the expectation . Note that we split f into and , and that we can integrate out exactly to give, ,

where , , . At this point, we can either (i) sample from , or (ii) integrate out analytically. Opting for the second approach gives , the former just omits the second covariance term and uses a sample for the predictive mean instead of .
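The analytic option can be illustrated with the standard sparse variational GP predictive from the inducing point literature cited above. The sketch assumes a Gaussian q(u) with mean `m` and covariance `S`, scalar codes, and an RBF kernel; these names and values are assumptions for the example and need not match the notation of the elided equations.

```python
import numpy as np

def rbf(a, b, lengthscale=0.5):
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def sparse_gp_predictive(Kff, Kfu, Kuu, m, S):
    """Marginal q(f) after integrating out u ~ N(m, S) under the conditional prior."""
    A = np.linalg.solve(Kuu, Kfu.T).T        # K_fu K_uu^{-1}, using symmetry of K_uu
    mean = A @ m
    # Conditional prior covariance plus the uncertainty propagated from q(u).
    cov = Kff - A @ Kfu.T + A @ S @ A.T
    return mean, cov

codes = np.linspace(-2.0, 2.0, 30)           # scalar stand-ins for weight codes
inducing = np.linspace(-2.0, 2.0, 5)         # inducing inputs
Kff, Kfu, Kuu = rbf(codes, codes), rbf(codes, inducing), rbf(inducing, inducing)
m = np.sin(inducing)                         # hypothetical variational mean
S = 0.01 * np.eye(5)                         # hypothetical variational covariance
mean, cov = sparse_gp_predictive(Kff, Kfu, Kuu, m, S)
mean_prior, cov_prior = sparse_gp_predictive(Kff, Kfu, Kuu, np.zeros(5), Kuu)
```

Setting `m` to zero and `S` to `Kuu` recovers the GP prior exactly, which is a convenient sanity check. Sampling u instead amounts to dropping the `A @ S @ A.T` term and replacing `m` by the sample.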

In contrast to GP regression and classification in which the likelihood term is factorized point-wise w.r.t. the parameters and thus their expectations only involve a low dimensional integral, we have to integrate out which for GPs entails inversion of the matrix K (which is when we don’t sample or the full term above). This is feasible for small neural networks with up to a few thousand weights, but becomes intractable for more general architectures. In order to scale to larger networks, we introduce a diagonal approximation, which given a sample looks as follows, . Whilst the diagonal approximation above might look poor at first glance, it is conditioned on a sample of the latent variables and thus the weights’ correlations induced by the hierarchical unit-structure are retained after integrating out . Such correlations are illustrated in Fig. 5, showing the marginal and conditional covariance structures for the weights of a small neural network, separated into diagonal and full covariance models. We also provide a qualitative and quantitative analysis of performance of different approximations to in the appendix, including the diagonal approximation presented here, and show that not only is this approximation fast but also that it performs competitively with full covariance models. Finally, the expected log-likelihood is approximated by with samples. The final lower bound is then optimized to obtain the variational parameters of , , and estimates for the noise in the meta-GP model, the kernel hyper-parameters and the inducing inputs.
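A sketch of the diagonal approximation, assuming the standard sparse GP predictive with a Gaussian q(u) of mean `m` and covariance `S`: only the per-weight variances are computed, so no matrix of size (number of weights) squared is ever formed, and weights are then drawn with the reparameterization trick. Function names, kernel choice, and sizes are illustrative assumptions.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale**2)

def diagonal_weight_posterior(codes, inducing, m, S, noise_var):
    """Per-weight mean and variance; cost is linear in the number of weights."""
    Kfu = rbf(codes, inducing)
    Kuu = rbf(inducing, inducing) + 1e-8 * np.eye(len(inducing))
    A = np.linalg.solve(Kuu, Kfu.T).T                  # K_fu K_uu^{-1}
    mean = A @ m
    prior_var = np.ones(len(codes))                    # k(c, c) = 1 for this unit-variance RBF
    # Row-wise sums give the diagonals of A K_uf and A S A^T without forming them.
    var = prior_var - (A * Kfu).sum(1) + ((A @ S) * A).sum(1) + noise_var
    return mean, var

def sample_weights(mean, var, rng):
    """Reparameterized draw: a deterministic function of (mean, var) and noise."""
    return mean + np.sqrt(var) * rng.standard_normal(mean.shape)

rng = np.random.default_rng(3)
codes = rng.standard_normal((200, 4))      # codes built from one sample of z
inducing = rng.standard_normal((8, 4))     # hypothetical inducing inputs
m = rng.standard_normal(8)
L = 0.1 * np.tril(rng.standard_normal((8, 8)))
S = L @ L.T                                # a valid covariance for q(u)
noise_var = 0.01
mean, var = diagonal_weight_posterior(codes, inducing, m, S, noise_var)
w = sample_weights(mean, var, rng)
```

Although each weight is sampled independently here, the means and variances are all functions of the same codes, so averaging over samples of the latent variables reintroduces correlations between weights.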

Figure 4: Predictive performance of various methods on a four-way classification problem. We compare the proposed approaches (MetaGP, MetaGP with an input-dependent RBF kernel and periodic kernel) to BNN with MFVI and HMC, DKL and MetaNN. Best viewed in colour. The background color shows the entropy of the predictive distribution. The contours show the 0.7 equiprobability contours. The bottom plots are zoomed-out versions of the corresponding top plots, showing the predictive entropy further from the training points.

4.2 Inference for the local model

The main difference in the local model is the dependence of weights on inputs. To handle inducing point kernels over both weight codes and inputs, we introduce inducing inputs where for . We then concatenate the dimensions of in Section 4.1 with the new inducing inputs to form the new inputs . The set of inputs now have dimensions . The fully instantiated covariance matrix would take the shape . As this kernel has Kronecker structure one could now consider using inference techniques such as in (flaxman2015fast). However, the tractability of the global kernel remains an issue even in this case. As such, we elect to inherit the diagonal approximation from Section 4.1 and apply it to the joint kernel, yielding an object of dimension . The lower bound computation in Section 4.1 can thus be reused but with and being input-dependent. We can handle large datasets by using inducing point kernels, which permit inference using minibatches. Another difference is the potential existence of the mapping in the model, which we tackle by introducing a variational distribution . We can estimate the evidence lower bound by also drawing unbiased samples from this and jointly optimizing its parameters with the rest of the variational parameters. The overall computational complexity with data-subsampling in this section and Section 4.1 is .
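The Kronecker structure mentioned above, and why the diagonal of the joint kernel is cheap, can be checked numerically. In this sketch the sizes, the kernel variance of 1.5, and the lengthscale of 0.7 are arbitrary assumptions; the point is only that a product kernel evaluated on all (input, code) pairs equals a Kronecker product of the two factor matrices.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale**2)

rng = np.random.default_rng(4)
n_weights, n_points = 6, 3
codes = rng.standard_normal((n_weights, 4))
inputs = rng.standard_normal((n_points, 1))

K_c = 1.5 * rbf(codes, codes)                 # global kernel over weight codes
K_x = rbf(inputs, inputs, lengthscale=0.7)    # auxiliary kernel over inputs
K_joint = np.kron(K_x, K_c)                   # row index n * n_weights + w

# Direct evaluation of the product kernel on every (input, code) pair.
pair_codes = np.tile(codes, (n_points, 1))            # codes cycle fastest
pair_inputs = np.repeat(inputs, n_weights, axis=0)    # each input repeated per weight
K_direct = 1.5 * rbf(pair_codes, pair_codes) * rbf(pair_inputs, pair_inputs, 0.7)

# The diagonal needs only the two factor diagonals, never the full matrix.
diag_joint = np.outer(np.diag(K_x), np.diag(K_c)).ravel()
```

The full joint matrix has (n_points * n_weights) squared entries, whereas its diagonal has n_points * n_weights, which is what makes a per-datapoint, per-weight approximation feasible with minibatches.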

5 Experiments

In this section, we evaluate the proposed priors and inference scheme on several regression and classification datasets. These were implemented using PyTorch (paszke2017automatic) and the code will be available upon acceptance. Additional results are included in the appendices. We use inducing points for all experiments in this section. All experiments were run on a MacBook Pro.

Figure 5: Marginal and conditional covariance structures over weights in a 1x50x1 BNN. Sampling from the posterior of the hierarchical model reveals that even a diagonal GP approximation can capture off-diagonal correlations induced through unit correlations. Also note the off-diagonal bands in the marginal plots above, which indicate the correlation structures induced by the latent variables of the hidden units connecting the layers. We remove the diagonal in the marginal plots for clarity.
Figure 6: Predictive performance of various methods on a sinusoidal dataset. We also provide a quantitative comparison in Table 1.

5.1 Synthetic classification example

We first illustrate the performance of the proposed model on a classification example. We generate a dataset of 100 data points and four classes, and use a BNN with one hidden layer of 50 hidden units with ReLU non-linearities, and two dimensional latent variables z. Figure 4 shows the predictive performance of the proposed priors and various alternatives, including BNN (with unit Normal priors on weights) with mean field Gaussian variational approximation (MFVI) (blundell2015weight) and Hamiltonian Monte Carlo (HMC) (neal1992bayesian), variational deep kernel learning (DKL) (wilson2016stochastic) and MetaNN (karaletsos2018probabilistic). We highlight that MetaGP-local with RBF kernel gives uncertainty estimates that are reminiscent of those of a GP model in that the predictions express “I don’t know” away from the training data, despite being a neural network under the hood. Following bradshaw2017adversarial, we also show the uncertainty for data further from the training instances. MetaGP-local(RBF) remains uncertain, as expected, for these points while MFVI and DKL produce arguably overconfident predictions.

5.2 Inductive Biases For Neural Networks With Input-Dependent Kernels

We explore the utility of the input-dependent prior towards modeling inductive biases for neural networks and evaluate predictive performance on a regression example. In particular, we generate 100 training points from a synthetic sinusoidal function and create two test sets that contain in-sample inputs and out-of-sample inputs, respectively. We test an array of models and inference methods, including BNN (with unit Normal priors on weights) with MFVI and HMC, GPs with diverse kernel functions, DKL, MetaGP and MetaGP-local with input dependence given the same kernels as the GPs. We use RBF and periodic kernels (mackay1998introduction) for weight modulation and the pure GP in this example. Figure 6 summarizes the results. Note that the periodic kernel allows the BNN to discover and encode periodicity in its weights, leading to confident long-range predictions compared to those of the RBF kernel and significantly better extrapolation than BNNs with independent weight priors can obtain given the amount of training data, even when running HMC instead of VI.

We evaluate the quantitative utility of input-dependence and inductive biases on two test sets that contain in-sample inputs (between the training data) and out-of-sample inputs (outside the training range), respectively. We report the performance of all methods in Table 1. The performance is measured by the root mean squared error (RMSE) and the negative log-likelihood (NLL) on the test set, and we explicitly evaluate separately for extrapolation and interpolation. In this example, the local MetaGP model is comparable to GP regression with a periodic kernel and superior to other methods, demonstrating good RMSE and NLL on both in-distribution and out-of-distribution examples.
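For reference, the two metrics can be computed as in the sketch here. It assumes a Gaussian predictive distribution summarized by a per-point mean and variance; how the predictive density is formed for each method in Table 1 is not specified in this section, so this is an assumption made for illustration.

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error of point predictions."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def gaussian_nll(y_true, pred_mean, pred_var):
    """Average negative log-likelihood under N(pred_mean, pred_var) per test point."""
    return float(np.mean(0.5 * np.log(2.0 * np.pi * pred_var)
                         + 0.5 * (y_true - pred_mean) ** 2 / pred_var))

y_true = np.array([1.0, 2.0])
y_pred = np.array([1.0, 4.0])
err = rmse(y_true, y_pred)
nll_exact = gaussian_nll(y_true, y_true, np.ones(2))     # perfect mean, unit variance
nll_tight = gaussian_nll(y_true, y_pred, np.full(2, 0.01))
nll_wide = gaussian_nll(y_true, y_pred, np.full(2, 4.0))
```

Unlike RMSE, the NLL penalizes confident mistakes heavily: the same wrong prediction scores far worse with a small predictive variance than with a large one, which is why overconfident extrapolation produces large NLL values.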

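| Method | Interpolation RMSE | Interpolation NLL | Extrapolation RMSE | Extrapolation NLL |
|---|---|---|---|---|
| BNN-MFVI | 0.17 | -0.04 | 3.51 | 88.12 |
| BNN-HMC | 0.12 | -0.69 | 4.34 | 10.98 |
| Exact GP-RBF | 0.11 | -0.81 | 0.55 | 0.75 |
| Exact GP-Periodic | 0.11 | -0.80 | 0.11 | -0.83 |
| DKL | 0.12 | -0.72 | 0.76 | 3.26 |
| MetaGP | 0.24 | 0.08 | 2.59 | 5.86 |
| MetaGP-Local[RBF] | 0.11 | -0.80 | 0.74 | 1.50 |
| MetaGP-Local[Periodic] | 0.11 | -0.76 | 0.12 | -0.69 |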
Table 1: Average test error and negative log-likelihood for the sinusoid example in Fig. 6, averaged over five runs. Lower is better.
Figure 7: The CDFs of predictive entropies on in-distribution and out-of-distribution test sets for various methods [Left] and the predictive class probability for representative samples from out-of-distribution test sets [Right].

5.3 Input Dependent Neural Networks For Uncertainty Quantification

Motivated by the performance of the proposed MetaGP-local model in the synthetic examples in Figure 6, this section tests the ability of this model class to produce calibrated predictive uncertainty on out-of-distribution samples. That is, for test samples that do not come from the training distribution, a robust and well-calibrated model should produce an uncertain predictive distribution and thus high predictive entropy. Such a model could find applications in safety-critical tasks or in areas where detecting unfamiliar inputs is crucial, such as active learning or reinforcement learning. In this experiment, we train a BNN classifier with one hidden layer of 100 rectified linear units on the MNIST dataset, with MetaGP-local-RBF only applied to the last layer of the network. The dimensions of the latent variables and the auxiliary inputs are both 2, with auxiliary inputs given by transforming MNIST images using a jointly learned linear projection . After training on MNIST, we compute the entropy of the predictions on various test sets, including notMNIST, fashionMNIST, Kuzushiji-MNIST, and uniform and Gaussian noise inputs. Following (lakshminarayanan2017simple; louizos2017multiplicative), the CDFs of the predictive entropies for various methods are shown in Fig. 7. A calibrated classifier should give a CDF that bends towards the top-left corner of the plot for in-distribution examples and, vice versa, towards the bottom-right corner of the plot for out-of-distribution inputs. In most out-of-distribution sets considered, except Gaussian random noise, MetaGP and MetaGP-local demonstrate superior performance to all comparators, including DKL. Notably, MAP estimation, often deployed in practice, tends to give very poor uncertainty estimates on out-of-distribution samples. We illustrate this behaviour and that of other methods on representative inputs of the Kuzushiji-MNIST dataset in Figure 7 and on MNIST digits in the appendix.
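The entropy CDFs used in this evaluation can be computed as sketched here. The array layout (Monte Carlo samples, data points, classes), the function names, and the synthetic probabilities are assumptions for the example; no results from the experiments are reproduced.

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    """Entropy of the Monte Carlo averaged predictive distribution.

    probs has shape (n_samples, n_points, n_classes); each row sums to one.
    """
    p = probs.mean(axis=0)
    return -(p * np.log(p + eps)).sum(axis=-1)

def empirical_cdf(values, grid):
    """Fraction of values that are less than or equal to each grid point."""
    values = np.sort(values)
    return np.searchsorted(values, grid, side="right") / len(values)

n_classes = 10
uniform = np.full((5, 4, n_classes), 1.0 / n_classes)    # maximally uncertain predictions
one_hot = np.zeros((5, 4, n_classes))
one_hot[..., 0] = 1.0                                    # fully confident predictions

h_uniform = predictive_entropy(uniform)
h_one_hot = predictive_entropy(one_hot)
grid = np.linspace(0.0, np.log(n_classes) + 1e-6, 50)
cdf_confident = empirical_cdf(h_one_hot, grid)
cdf_uncertain = empirical_cdf(h_uniform, grid)
```

For confident predictions the CDF jumps to one at the left edge of the grid, while for uniform predictions it stays at zero until the maximum entropy log(10), matching the top-left versus bottom-right reading of Fig. 7.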

5.4 Active learning

Figure 8: Active learning with BNNs using mean-field Gaussian variational inference [MFVI] and a meta-GP hierarchical prior [MetaGP] on several UCI regression datasets. Each trace shows the root mean squared error (RMSE) averaged across 40 runs.

We next stress-test the performance of the proposed model in a pool-based active learning setting for real-valued regression, where limited training data is provided initially and the target is to sequentially select points from a pool set to add to the training set. The criterion to select the next best point from the pool set is based on the entropy of the predictive distribution, i.e. we pick the one with the highest entropy. Note that this selection procedure can be interpreted as selecting points that maximally reduce the posterior entropy of the network parameters (houlsby2011bayesian). Four UCI regression datasets were considered, each with 40 random train/test/pool splits. For each split, the initial train set has 20 data points, the test set has 100 data points, and the remaining points are used for the pool set, similar to the active learning set-up in hernandez-lobato2015prob. We compare the performance of the proposed model and inference scheme to that of Gaussian mean-field variational inference and show the average results in Figure 8. Across all runs, we observe that active learning is superior to random selection and, more crucially, that the proposed model and inference scheme yield comparable or better predictive errors with a similar number of queries.
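The selection loop can be sketched as follows. To keep the example self-contained, a Bayesian linear regression with a closed-form posterior stands in for the BNN; this substitution, the function names, and the synthetic data are assumptions made for illustration. For a Gaussian predictive distribution the entropy is 0.5 log(2 pi e sigma^2), so choosing the highest entropy is the same as choosing the largest predictive variance.

```python
import numpy as np

def gaussian_entropy(var):
    return 0.5 * np.log(2.0 * np.pi * np.e * var)

def blr_predict(X_train, y_train, X_query, noise_var=0.1, prior_var=1.0):
    """Stand-in model: Bayesian linear regression with a closed-form posterior."""
    d = X_train.shape[1]
    precision = np.eye(d) / prior_var + X_train.T @ X_train / noise_var
    cov = np.linalg.inv(precision)
    mean_w = cov @ X_train.T @ y_train / noise_var
    mean = X_query @ mean_w
    var = np.einsum("nd,de,ne->n", X_query, cov, X_query) + noise_var
    return mean, var

def active_learning(X_train, y_train, X_pool, y_pool, n_queries):
    """Repeatedly move the highest-entropy pool point into the training set."""
    for _ in range(n_queries):
        _, var = blr_predict(X_train, y_train, X_pool)
        pick = int(np.argmax(gaussian_entropy(var)))
        X_train = np.vstack([X_train, X_pool[pick:pick + 1]])
        y_train = np.append(y_train, y_pool[pick])
        X_pool = np.delete(X_pool, pick, axis=0)
        y_pool = np.delete(y_pool, pick)
    return X_train, y_train, X_pool, y_pool

rng = np.random.default_rng(5)
X = rng.standard_normal((70, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.3 * rng.standard_normal(70)
X_tr, y_tr, X_pool, y_pool = active_learning(X[:20], y[:20], X[20:], y[20:], n_queries=5)
```

Replacing `blr_predict` with any model that returns a predictive mean and variance leaves the loop unchanged, which is how different priors and inference schemes can be compared under the same acquisition rule.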

6 Related work

There is a long history of research on developing (approximate) Bayesian inference methods for BNNs, e.g. in (neal1992bayesian; neal2012bayesian; ghahramani2016history). Our work differs in that the model employs a hierarchical prior, and inference is done in a lower-dimensional latent space instead of the weight space. The variational approximation is chosen such that the marginal distribution over the weights is non-Gaussian and the correlations between weights are retained, in contrast to the popular mean-field Gaussian approximation. Imposing structure over the weights with a carefully chosen prior has been observed to improve predictive performance (ghosh2018structured; neal2012bayesian; blundell2015weight), but it has remained elusive how to express prior knowledge or handle interpolation or extrapolation in such models. Modern deep Bayesian learning approaches often involve fusing neural networks and GPs, such as in deep kernel learning (wilson2016stochastic), which layers a GP on top of a neural network feature extractor. Another notable example is (pearce2019expressive), which blends kernels and activation functions to induce desired properties through architectural choices, but does not express these assumptions as a weight prior. The functional regularization approach introduced in (sun2019functional) shares some of the motivations with our paper, but implements it very differently by explicitly instantiating a GP and performing a complex training scheme to learn neural networks that match that GP. Asymptotically, they match the GP, while in our model (i) the properties we care about are already built into the weight prior allowing direct training on a dataset without the involved minimax approach, and (ii) our posterior can depart from that restrictive prior as it fundamentally only guides a weight based model, i.e. by learning posterior kernel parameters for to eliminate its influence on (such as wide lengthscales).
Another related theme is hyper-networks, the core idea of which is to generate network parameters using another network (see e.g. ha2016hypernetworks; stanley2009hypercube). Our model resembles a GP-LVM (lawrence2004gaussian) hyper-GP, with a key structural assumption of node latent variables as introduced in (karaletsos2018probabilistic) to enable compact prediction per weight instead of per weight tensor.

7 Summary

We proposed a GP-based hierarchical prior over neural network weights, and a modification that permits input-dependent weight priors, along with an effective approximate inference strategy. We demonstrated utility of these models for interpolation, extrapolation, uncertainty quantification and active learning benchmarks, outperforming strong baselines. We plan to evaluate the performance of the model on more challenging decision making tasks.


Appendix A Additional Background on Bayesian neural networks and variational inference

Consider a training set of input-output pairs and a neural network, parameterized by weights and biases, that describes the distribution over an output given an input. We follow a Bayesian approach by placing a prior distribution over the network parameters and obtaining their posterior distribution, which involves calculation of the marginal likelihood. However, obtaining the posterior and the marginal likelihood exactly is intractable when the training set is large or when the network is large and, as such, approximation methods are often required. In particular, mean-field Gaussian variational inference (MFVI) has recently become a method of choice for approximate inference for Bayesian neural networks due to its simplicity and the recently popularized reparameterization trick [salimans2013fixed, kingma2013auto, titsias2014doubly, blundell2015weight]. MFVI sidesteps the intractability by positing a diagonal Gaussian approximation to the posterior and optimising an approximate lower bound to the marginal likelihood, in which the intractable expectation is estimated using parameter samples drawn from the approximation. Note that the mean-field variational Gaussian approximation with a standard normal prior is often outperformed by point estimation in certain settings [trippe2018overpruning]. Although MFVI is practical and able to give reasonable uncertainty estimates, improving it is still an active research area, whose main focuses are (i) improving the reparameterization gradient estimator to enable faster convergence [miller2017reducing, wu2018fixing], (ii) replacing the typical standard normal prior by a structured prior that better models the structures present in the weights a priori [ghosh2018structured, neal2012bayesian, blundell2015weight], and (iii) using structured variational approximations that can potentially capture weight correlations in the posterior [louizos2016structured, zhang2017noisy].
This paper builds on the two latter themes and proposes a hierarchical model for the prior and a structured variational scheme that explicitly model and infer weight structures.
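To make the MFVI objective concrete, the sketch below estimates the lower bound for a toy Bayesian linear regression model. It is an illustration under stated assumptions, not the code used in this paper: the model is linear rather than a neural network, the prior is a standard normal, the noise variance is fixed, and all function and variable names are hypothetical.

```python
import numpy as np

def kl_diag_gaussian_to_std_normal(mu, log_sigma):
    """KL[N(mu, diag(sigma^2)) || N(0, I)] in closed form."""
    sigma2 = np.exp(2.0 * log_sigma)
    return 0.5 * np.sum(sigma2 + mu**2 - 1.0 - 2.0 * log_sigma)

def log_likelihood(w, x, y, noise_var):
    """Gaussian log-likelihood of targets y under the linear model x @ w."""
    resid = y - x @ w
    n = y.shape[0]
    return (-0.5 * n * np.log(2.0 * np.pi * noise_var)
            - 0.5 * np.sum(resid**2) / noise_var)

def elbo_estimate(mu, log_sigma, x, y, noise_var, num_samples, rng):
    """Monte Carlo estimate of the variational lower bound."""
    eps = rng.standard_normal((num_samples, mu.shape[0]))
    # Reparameterization: w = mu + sigma * eps with eps ~ N(0, I).
    w = mu + np.exp(log_sigma) * eps
    expected_ll = np.mean([log_likelihood(w_k, x, y, noise_var) for w_k in w])
    return expected_ll - kl_diag_gaussian_to_std_normal(mu, log_sigma)

# Hypothetical toy data generated from known weights.
rng = np.random.default_rng(0)
x = rng.standard_normal((30, 2))
w_true = np.array([1.5, -0.7])
y = x @ w_true + 0.1 * rng.standard_normal(30)

log_sigma = np.log(0.05) * np.ones(2)
elbo_good = elbo_estimate(w_true, log_sigma, x, y, 0.01, 64, rng)
elbo_bad = elbo_estimate(np.zeros(2), log_sigma, x, y, 0.01, 64, rng)
```

In practice the bound would be maximised with respect to `mu` and `log_sigma` using an automatic differentiation framework; here it is only evaluated at two settings to show that a mean close to the generating weights gives a larger value.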

Appendix B Extra experimental results

B.1 An empirical evaluation of various approximations for the conditional distribution over the weights

In this section, we analyze the impact of different approximations to the covariance matrix of the conditional distribution over the weights:

If we use the exact, fully correlated Gaussian distribution above, it is necessary to sample from this distribution to evaluate the lower bound. This step costs where is the number of parameters in the network.

The complexity can be greatly improved by making a diagonal approximation to as follows,

Sampling from this distribution can be done in where M is the number of pseudo-points.

This can be further approximated by assuming a diagonal covariance matrix,

The variational bound can then be evaluated by drawing samples from as in the above approximation, or by drawing activity samples by employing the local reparameterization trick [kingma2015variational].
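The difference in sampling cost between the two extremes, a fully correlated Gaussian and a fully diagonal one, can be illustrated as follows. This sketch does not reproduce the intermediate pseudo-point-based approximation; the covariance matrix, the jitter value and the function names are assumptions made for the example.

```python
import numpy as np

def sample_full(mean, cov, rng, jitter=1e-6):
    """Draw one sample from N(mean, cov) with a dense covariance.

    The Cholesky factorization is cubic in the number of weights.
    """
    dim = mean.shape[0]
    chol = np.linalg.cholesky(cov + jitter * np.eye(dim))
    return mean + chol @ rng.standard_normal(dim)

def sample_diagonal(mean, cov, rng):
    """Draw one sample keeping only the marginal variances of cov.

    This is linear in the number of weights but drops all correlations.
    """
    std = np.sqrt(np.diag(cov))
    return mean + std * rng.standard_normal(mean.shape[0])

# Hypothetical example: 50 weights with an RBF covariance over 1D codes.
codes = np.linspace(-2.0, 2.0, 50)
cov = np.exp(-0.5 * ((codes[:, None] - codes[None, :]) / 0.3) ** 2)
mean = np.zeros(50)

w_full = sample_full(mean, cov, np.random.default_rng(0))
w_diag = sample_diagonal(mean, cov, np.random.default_rng(0))
```

Samples from `sample_full` vary smoothly across neighbouring codes, whereas samples from `sample_diagonal` have the same marginal variances but are independent across weights.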

We evaluate the performance of the exact and approximate conditional distributions above on a range of toy regression and classification tasks, and show representative results in Figs. 9 and 10. We note that the diagonal approximation is fast and gives qualitatively similar performance to the more structured approximation and the exact case, both when there is a single GP for all weights in the network and when there are multiple GPs, one for each weight layer in the network.


Figure 9: An evaluation of the covariance matrix approximations in a toy regression example. Top: objective function during training vs epoch/time. Bottom: Predictions after training using one of the approximations discussed in the text. Global: there is one GP for all weights in the network. Layer: there are multiple GPs, one for each weight layer in the network. Note that we are not using the auxiliary kernel here. Best viewed in colour.
Figure 10: An evaluation of the covariance matrix approximations in a toy classification example. Top: objective function during training vs epoch/time. Bottom: Predictions after training using one of the approximations discussed in the text. Global: there is one GP for all weights in the network. Layer: there are multiple GPs, one for each weight layer in the network. Note that we are not using the auxiliary kernel here. Best viewed in colour.

B.2 Results on a synthetic regression example

In this section, we demonstrate the performance of the proposed priors on a 1D test function, as used in [louizos2019fnp]. We compare to a BNN with independent Gaussian priors and a mean-field Gaussian variational approximation, and to MetaNN [karaletsos2018probabilistic]. The training points, predictive means and error bars are shown in Fig. 11.

Figure 11: Predictive performance of various methods on a 1D test function. We compare the proposed approaches (MetaGP and MetaGP with an input-dependent kernel) to BNN-MFVI and MetaNN. Best viewed in colour.

B.3 Robustness in various data regimes for a toy regression problem

In this experiment, we evaluate the qualitative performance of various methods, including MFVI, MetaNN, MetaGP and MetaGP with a local, input-dependent kernel, on a toy regression problem in different data regimes. In particular, we consider 10, 20, and 50 training points, and plot the predictions in Fig. 12. MetaGP demonstrates consistent performance across all data regimes, comparable to that of MetaNN. The input-dependent kernel further improves performance in the out-of-distribution region.

Figure 12: Performance of mean-field variational inference, MetaNN with variational inference and MetaGP with variational inference on a toy regression problem with various number of training points. Best viewed in colour.

B.4 Robustness of MetaGP with network architectures

In this experiment, we compare the performance of MetaGP for various numbers of hidden units (20, 50 and 100) and two activation functions (Tanh and ReLU) on a toy regression problem. The observation noise is fixed in this experiment. We observe that the performance of the models is in general consistent across different activation functions and numbers of hidden units. We show the results in Fig. 13.

Figure 13: Performance of MetaGP on a toy regression problem, with various numbers of hidden units and different activation functions. Best viewed in colour.

B.5 Effect of input-dependent kernels

To understand the impact of the auxiliary kernel on the predictions, we use a model trained on the sinusoid dataset, as shown in the main text, and vary the period hyperparameter of the kernel whilst keeping the other hyperparameters and the variational parameters fixed. The predictions for a few values of the period hyperparameter are shown in Fig. 14. We note that the variation/period in the data is captured by weight modulation, governed by the input-dependent kernel. Changing the period hyperparameter affects how quickly the weights change with respect to the input.

Figure 14: We first train a model with an input-dependent kernel on a sinusoid dataset (top left) and then vary the period hyperparameter of the input-dependent kernel whilst keeping the other hyperparameters and the variational parameters fixed (others). Best viewed in colour.
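The role of the period hyperparameter can be illustrated with the standard exponentiated sine-squared (periodic) kernel. We assume this kernel form for the sketch; the function name `periodic_kernel` and its default hyperparameter values are hypothetical.

```python
import numpy as np

def periodic_kernel(x1, x2, period, lengthscale=1.0, variance=1.0):
    """Exp-sine-squared kernel between two sets of 1D inputs."""
    diff = x1[:, None] - x2[None, :]
    return variance * np.exp(
        -2.0 * np.sin(np.pi * diff / period) ** 2 / lengthscale**2
    )

x = np.array([0.0, 0.5, 1.0])
# With period 1, inputs one unit apart are perfectly correlated.
k_short = periodic_kernel(x, x, period=1.0)
# With period 2, the same pair of inputs is only weakly correlated.
k_long = periodic_kernel(x, x, period=2.0)
```

Under this kernel, inputs separated by a whole number of periods share the same weight modulation, so changing the period alone rescales how quickly the induced weights vary with the input.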

B.6 MNIST experiment: full figures

In this section, we include the full figures of the MNIST out-of-distribution uncertainty experiment, as well as additional results using deep kernel learning [wilson2016deep]. In particular, we employ the same network architecture with the last layer replaced by multiple independent GPs, one for each class (output dimension). As exact inference is intractable, variational inference based on inducing points is employed; we used 50 inducing points for each output. The full results of all models/methods considered are shown in Fig. 15. For clarity, the results of deep kernel learning and MetaGP are shown in LABEL:fig:mnist_entropy_3. MetaGP with the input-dependent kernel shows good performance, outperforming deep kernel learning in all cases. In addition, we include the full figures for the predictive distributions on representative test examples in Figs. 16 and 17.

Figure 15: Full results of the MNIST out-of-distribution uncertainty experiment. Best viewed in colour.
Figure 16: Predictive distribution for representative MNIST test examples by various methods. Best viewed in colour.
Figure 17: Predictive distribution for representative KMNIST test examples by various methods. Best viewed in colour.

B.7 A toy active learning problem

In this section, we provide a visualisation of the predictive performance of different methods in an active learning setting. Please see Fig. 18 and the associated caption.

Figure 18: Active learning with BNNs using maximum a posteriori estimation [BNN-MAP], mean-field Gaussian variational inference [BNN-MFVI] and a meta-GP hierarchical prior [BNN-MetaGP] on a toy multi-class classification task. For each plot, the filled circle markers are the current training points, with different colours illustrating different classes. The shaded crosses are the examples in the pool set, one of which we wish to pick and evaluate to be included in the training set. The unfilled circle markers are the examples from the pool set selected at a step. The objective function for selecting points from the pool set is the entropy of the predictive probability. Best viewed in colour.
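For classification, the entropy of the predictive probability can be estimated from Monte Carlo samples of the class probabilities. The sketch below is a minimal illustration; the array layout and the names `predictive_entropy` and `select_from_pool` are assumptions rather than details of the experimental code.

```python
import numpy as np

def predictive_entropy(prob_samples, eps=1e-12):
    """Entropy of the sample-averaged class probabilities.

    prob_samples: array of shape (S, P, C) with S posterior samples of
        class probabilities for P pool points and C classes.
    """
    mean_probs = prob_samples.mean(axis=0)
    return -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)

def select_from_pool(prob_samples):
    """Index of the pool point with the highest predictive entropy."""
    return int(np.argmax(predictive_entropy(prob_samples)))

# Hypothetical pool of 3 points and 3 classes, with 2 posterior samples.
prob_samples = np.array([
    [[0.98, 0.01, 0.01], [0.90, 0.05, 0.05], [0.70, 0.20, 0.10]],
    [[0.96, 0.02, 0.02], [0.05, 0.90, 0.05], [0.60, 0.30, 0.10]],
])
chosen = select_from_pool(prob_samples)
entropies = predictive_entropy(prob_samples)
```

The second pool point is selected here because the two posterior samples disagree on its class, which flattens the averaged predictive distribution.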

B.8 Applications to multi-task learning

We further investigate using the proposed model for multi-task learning. In particular, the latent variables and the corresponding hyperparameters and variational parameters can be shared across different tasks, whilst the meta mapping and the input-dependent kernel are private to each individual task. We first train the model on four regression tasks, each corresponding to a sinusoid of a particular frequency. At test time, a novel test set is shown to the model. The hyperparameters of the input kernel and the variational parameters corresponding to this new test set are optimised while the other hyperparameters and the latent variables are kept fixed. We evaluate the performance of the model on the novel test sets to see how the latent variables can be reused and shared across tasks to facilitate fast adaptation to new settings. The performance of the model on the tasks used for training and on new tasks at test time is shown in Fig. 19. This result demonstrates the ability of the model, trained on multiple related tasks, to adapt faithfully and quickly to new settings.

Figure 19: Training on multiple related tasks and adaptation to novel tasks at test time. In this case, the latent variables (as well as weight code hyperparameters) are shared across tasks while each individual task has its own input-dependent kernel. At test time, only the private parameters for the new task are re-initialised and optimised. Best viewed in colour.


  1. Here, we present a homoscedastic noise model for the weights, but the model is readily adaptable to a heteroscedastic noise model which we omit for clarity.
  2. The conditioning on and in and is made implicit here and in the rest of this paper.
  3. We can also use the local reparameterization trick (kingma2015variational) to reduce variance.
  4. Specifically, where , , .