Towards Verified Stochastic Variational Inference for Probabilistic Programs
Probabilistic programming is the idea of writing models from statistics and machine learning using program notations and reasoning about these models using generic inference engines. Recently its combination with deep learning has been explored intensely, which led to the development of so-called deep probabilistic programming languages, such as Pyro, Edward and ProbTorch. At the core of this development lie inference engines based on stochastic variational inference algorithms. When asked to find information about the posterior distribution of a model written in such a language, these algorithms convert this posterior-inference query into an optimisation problem and solve it approximately by a form of gradient ascent or descent. In this paper, we analyse one of the most fundamental and versatile variational inference algorithms, called the score estimator or REINFORCE, using tools from denotational semantics and program analysis. We formally express what this algorithm does on models denoted by programs, and expose implicit assumptions made by the algorithm on the models. The violation of these assumptions may lead to an undefined optimisation objective or the loss of the convergence guarantee of the optimisation process. We then describe rules for proving these assumptions, which can be automated by static program analyses. Some of our rules use nontrivial facts from continuous mathematics, and let us replace requirements about integrals in the assumptions, such as integrability of functions defined in terms of programs’ denotations, by conditions involving differentiation or boundedness, which are much easier to prove automatically (and manually). Following our general methodology, we have developed a static program analysis for the Pyro programming language that aims at discharging the assumption about what we call model-guide support match.
Our analysis is applied to the eight representative model-guide pairs from the Pyro webpage, which include sophisticated neural network models such as AIR. It finds a bug in two of these cases, and shows that the assumptions are met in the others.
Probabilistic programming refers to the idea of writing models from statistics and machine learning using program notations and reasoning about these models using generic inference engines. It has been the subject of active research in machine learning and programming languages, because of its potential for enabling scientists and engineers to design and explore sophisticated models easily; when using these languages, they no longer have to worry about developing custom inference engines for their models, a highly-nontrivial task requiring expertise in statistics and machine learning. Several practical probabilistic programming languages now exist, and are used for a wide range of applications (Carpenter et al., 2017; Minka et al., 2014; Gordon et al., 2014; Goodman et al., 2008; Mansinghka et al., 2014; Wood et al., 2014; Tolpin et al., 2016; Narayanan et al., 2016; Gehr et al., 2016).
In this paper, we consider inference engines that lie at the core of so-called deep probabilistic programming languages, such as Pyro (Bingham et al., 2019), Edward (Tran et al., 2016, 2018) and ProbTorch (Siddharth et al., 2017). These languages let their users freely mix deep neural networks with constructs from probabilistic programming, in particular, those for writing Bayesian probabilistic models. In so doing, they facilitate the development of probabilistic deep-network models that may address the problem of measuring the uncertainty in current non-Bayesian deep-network models; a non-Bayesian model may predict that the price of energy goes up and that of a house goes down, but it cannot express, for instance, that the model is very confident in the first prediction but not in the second.
The primary inference engines for these deep probabilistic programming languages implement stochastic variational inference algorithms. Converting inference problems into optimisation problems is the high-level idea of these algorithms. (Footnote 1: The inference problems in their original forms involve summation/integration/counting problems, which are typically more difficult than optimisation problems. The variational-inference algorithms can convert the former problems to the latter ones because they look for approximate, not exact, answers to the former.) When asked to find information about the posterior distribution of a model written in such a language, these algorithms convert this question to an optimisation problem and solve the problem approximately by performing a gradient descent or ascent on the optimisation objective. The algorithms work smoothly with gradient-based parameter-learning algorithms for deep neural networks, which is why they form the backbone for deep probabilistic programming languages.
In this paper, we analyse one of the most fundamental and versatile variational inference algorithms, called the score estimator or REINFORCE (Williams, 1992; Paisley et al., 2012; Wingate and Weber, 2013; Ranganath et al., 2014), using tools from denotational semantics and program analysis (Cousot and Cousot, 1977, 1979, 1992). We formally express what this algorithm does on models denoted by probabilistic programs, and expose implicit assumptions made by the algorithm on the models. The violation of these assumptions can lead to an undefined optimisation objective or the loss of the convergence guarantee of the optimisation process. We then describe rules for proving these assumptions, which can be automated by static program analyses. Some of our rules use nontrivial facts from continuous mathematics, and let us replace requirements about integrals in the assumptions, such as integrability of functions defined in terms of programs’ denotations, by conditions involving differentiation or boundedness, which are much easier to prove automatically (and manually) than the original requirements.
Following our general methodology, we have developed a static program analysis for the Pyro programming language that can discharge the assumption of the inference algorithm about so-called model-guide pairs. In Pyro and other deep probabilistic programming languages, a program denoting a model typically comes with a companion program, called guide, decoder, or inference network. This companion, which we call the guide, helps the inference algorithm to find a good approximation to what the model ultimately denotes under a given dataset (i.e., the posterior distribution of the model under the dataset); the algorithm uses the guide to fix the search space of approximations, and solves an optimisation problem defined on that space. A model and a guide should satisfy an important correspondence property, which says that they should use the same sets of random variables, and for any such random variable, if the probability of the variable having a particular value is zero in the model, it should also be zero in the guide. If the property is violated, the inference algorithm may attempt to solve an optimisation problem with an undefined optimisation objective and return parameter values that do not make any sense. Our static analysis checks this correspondence property for Pyro programs. When applied to eight representative model-guide pairs from the Pyro webpage, which include sophisticated neural network models such as Attend-Infer-Repeat (AIR), the analysis found a bug in two of these cases, and proved that the property holds in the others.
Another motivation for this paper is to demonstrate an opportunity for programming languages and verification research to have an impact on the advances of machine learning and AI technologies, in the design and implementation of models. One popular question is: what properties should we verify on machine-learning programs? Multiple answers have been proposed, which led to excellent research results, such as those on robustness of neural networks (Mirman et al., 2018). But most of the existing research focuses on the final outcome of machine learning algorithms, not the process of applying these algorithms. One of our main objectives is to show that the process often relies on multiple assumptions on models and finding automatic ways for discharging these assumptions can be another way of making PL and verification techniques contribute. While our suggested solutions are not complete, they are intended to show the richness of this type of problems in terms of theory and practice.
We summarise the contributions of the paper:
We formally express the behaviour of the most fundamental variational inference algorithm on probabilistic programs using denotational semantics, and identify requirements on program denotations that are needed for this algorithm to work correctly.
We describe conditions that imply the identified requirements but are easier to prove. The sufficiency of the conditions relies on nontrivial results from continuous mathematics. We sketch a recipe for building program analyses for checking these conditions automatically.
We present a static analysis for the Pyro language that checks the correspondence requirement of model-guide pairs. The analysis is based on our recipe, but extends it significantly to address challenges for dealing with features of the real-world language. Our analysis has successfully verified six representative Pyro model-guide examples, and found a bug in two others.
2. Variational Inference and Verification Challenges by Examples
We start by explaining informally the idea of stochastic variational inference (in short SVI), one fundamental SVI algorithm, and the verification challenges that arise when we use this algorithm.
2.1. Stochastic variational inference
In a probabilistic programming language, we specify a model by a program. The program model() in Figure 1(a) is an example. It describes a joint probability density on two real-valued random variables. The value of the first is not observed, while the second is observed to have a fixed value. Finding out the value of the unobserved variable is the objective of writing this model. The joint density is expressed in terms of prior and likelihood in the program. The prior of the unobserved variable is a normal distribution with the mean and standard deviation fixed in the program, and it expresses the belief about the possible value of that variable before any observation. The likelihood is a normal distribution whose mean and standard deviation take one of two pairs of values depending on the sign of the value of the unobserved variable. The purpose of most inference algorithms is to compute exactly or approximately the posterior density given a prior and a likelihood. In our example, the posterior is the product of the prior and the likelihood at the observed value, normalised so that it integrates to one.
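In generic notation, the posterior is given by Bayes' rule as sketched below. The symbols are assumptions introduced here for illustration: $v$ stands for the unobserved variable, $o$ for the observed value, $p(v)$ for the prior and $p(o \mid v)$ for the likelihood.

```latex
p(v \mid o)
  \;=\; \frac{p(v, o)}{\int p(v', o)\, dv'}
  \;=\; \frac{p(v)\, p(o \mid v)}{\int p(v')\, p(o \mid v')\, dv'}
```

The denominator is the normalising constant; it is an integral over all values of the unobserved variable, which is what makes exact posterior inference hard in general.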
Intuitively, the posterior expresses an updated belief on the unobserved variable upon observing the data. The dashed blue and solid orange lines in Figure 1(b) show the prior and posterior densities, respectively. Note that the density of positive values of the unobserved variable went up from the prior to the posterior. This is because when the unobserved variable is positive, the mean of the observed variable is closer to the observed value than the alternative mean is.
SVI algorithms approach the posterior inference problem from the optimisation angle. They consider a collection of approximating distributions to a target posterior, formulate the problem of finding a good approximation in the collection as an optimisation problem, and solve the optimisation problem approximately. The solution becomes the result of such an algorithm. In Pyro, the users specify such a collection by a single parameterised program called a guide; the collection can be generated by instantiating the parameters with different values. The program guide() in Figure 1(a) is such an example. It has a real-valued parameter $\theta$ (written as theta in the program), and states that the probability density of the unobserved variable is a normal distribution with unknown mean $\theta$ and a fixed standard deviation. Lines 13–17 in the figure show how to apply a standard SVI engine of Pyro (called Trace_ELBO) to find a good $\theta$. They instruct the engine to find a $\theta$ that minimises the KL divergence from the density of the guide to the posterior.
The optimisation objective is the KL divergence from the density of the guide to the posterior. It measures the similarity between the two densities, having a small value when the densities are similar. The KL divergence is drawn in Figure 1(c) as a function of $\theta$, and the dotted green line in Figure 1(b) draws the density of the guide at the optimum $\theta$. Note that the mean of this distribution is biased toward the positive side, which reflects the fact that the unobserved variable being positive has a higher probability than its negation in the posterior distribution.
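For reference, the KL divergence used as the objective can be written as follows. This is the standard definition, stated in notation assumed here for illustration: $q_\theta(v)$ is the density of the guide, $p(v, o)$ the joint density of the model at the observed value $o$, and $p(o)$ the normalising constant.

```latex
\mathrm{KL}\big(q_\theta(v) \,\|\, p(v \mid o)\big)
  \;=\; \int q_\theta(v)\, \log \frac{q_\theta(v)}{p(v \mid o)}\, dv
  \;=\; \log p(o) \;+\; \int q_\theta(v)\, \log \frac{q_\theta(v)}{p(v, o)}\, dv
```

Since $\log p(o)$ does not depend on $\theta$, minimising the KL divergence only requires the joint density, not the posterior itself. The integral is defined only if $q_\theta(v) > 0$ implies $p(v, o) > 0$ and the integrand is integrable.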
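One of the most fundamental and versatile algorithms for SVI is the score estimator (also called REINFORCE). Write $v$ for the unobserved random variable, $o$ for the observed value, $q_\theta(v)$ for the density of the guide, and $p(v, o)$ for the joint density of the model. The algorithm repeatedly improves $\theta$ in two steps. First, it estimates the gradient of the optimisation objective with samples from the current $q_\theta$:
\[ \nabla_\theta \mathrm{KL}\big(q_\theta(v) \,\|\, p(v \mid o)\big) \;\approx\; \frac{1}{N} \sum_{i=1}^{N} \big(\nabla_\theta \log q_\theta(v_i)\big) \cdot \log \frac{q_\theta(v_i)}{p(v_i, o)}, \]
where $v_1, \ldots, v_N$ are independent samples from the distribution $q_\theta$. Then, the algorithm updates $\theta$ with the estimated gradient (a specific learning rate $\eta > 0$ is chosen to improve readability):
\[ \theta \;\leftarrow\; \theta - \eta \cdot \frac{1}{N} \sum_{i=1}^{N} \big(\nabla_\theta \log q_\theta(v_i)\big) \cdot \log \frac{q_\theta(v_i)}{p(v_i, o)}. \]
When the learning rate is adjusted according to a known scheme, the algorithm is guaranteed to converge to a local optimum (in many cases) because its gradient estimate satisfies the following unbiasedness property (in those cases):
\[ \nabla_\theta \mathrm{KL}\big(q_\theta(v) \,\|\, p(v \mid o)\big) \;=\; \mathbb{E}\left[ \frac{1}{N} \sum_{i=1}^{N} \big(\nabla_\theta \log q_\theta(v_i)\big) \cdot \log \frac{q_\theta(v_i)}{p(v_i, o)} \right] \tag{1} \]
where the expectation is taken over the independent samples $v_1, \ldots, v_N$ from $q_\theta$.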
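The two steps of the score estimator (gradient estimate, then parameter update) can be sketched in plain Python as follows. This is a minimal sketch and not Pyro's Trace_ELBO implementation. All numeric constants (prior, likelihood, guide standard deviation, observed value) and all function names are assumptions chosen for illustration.

```python
import math
import random

def log_normal_pdf(x, mean, std):
    """Log density of the normal distribution Normal(mean, std) at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

# Hypothetical model constants (assumptions, not taken from the text).
PRIOR_MEAN, PRIOR_STD = 0.0, 5.0
OBSERVED = 0.0
GUIDE_STD = 1.0

def log_joint(v):
    """log p(v, o): log prior plus log likelihood; the likelihood depends on the sign of v."""
    lik_mean = 1.0 if v > 0 else -2.0
    return log_normal_pdf(v, PRIOR_MEAN, PRIOR_STD) + log_normal_pdf(OBSERVED, lik_mean, 1.0)

def log_guide(v, theta):
    """log q_theta(v) for a normal guide with mean theta."""
    return log_normal_pdf(v, theta, GUIDE_STD)

def grad_log_guide(v, theta):
    """Score function: derivative of log q_theta(v) with respect to theta."""
    return (v - theta) / GUIDE_STD ** 2

def score_gradient_estimate(theta, num_samples, rng):
    """Step 1: average of score * log(q/p) over samples drawn from the guide."""
    total = 0.0
    for _ in range(num_samples):
        v = rng.gauss(theta, GUIDE_STD)
        total += grad_log_guide(v, theta) * (log_guide(v, theta) - log_joint(v))
    return total / num_samples

def run_svi(theta, steps, num_samples, learning_rate, seed=0):
    """Step 2, repeated: move theta against the estimated gradient."""
    rng = random.Random(seed)
    for _ in range(steps):
        theta -= learning_rate * score_gradient_estimate(theta, num_samples, rng)
    return theta
```

A call such as `run_svi(3.0, steps=2000, num_samples=10, learning_rate=0.01)` runs the loop from an initial parameter value of 3.0. The result is noisy because each gradient estimate is a random quantity.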
2.2. Verification challenges
We now give two example model-guide pairs that illustrate verification challenges related to SVI.
The first example appears in Figure 2(a). It is the Bayesian regression example from the Pyro webpage (this example is among the benchmarks used in §8), which solves the problem of finding a line that interpolates a given set of points in $\mathbb{R}^2$.
The problem with this example is that the KL divergence of its model-guide pair, the main optimisation objective in SVI, is undefined. The model and guide in the figure both use the random variable sigma, but they give it different non-zero-probability regions, called supports, and the support in the guide is not contained in the support in the model. But the KL divergence from a guide to a model is defined only if for every random variable, its support in the guide is included in that in the model. We point out that this support mismatch was found by our static analyser explained in §8.
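The support-inclusion requirement can be illustrated with a small check over interval supports. This is a sketch of the property only, not the static analysis of §8. The dictionary encoding, the function names, and the concrete intervals are assumptions made for illustration, and the check ignores whether interval endpoints are open or closed.

```python
import math

def support_included(guide_support, model_support):
    """True if the interval guide_support = (lo, hi) lies inside model_support."""
    g_lo, g_hi = guide_support
    m_lo, m_hi = model_support
    return m_lo <= g_lo and g_hi <= m_hi

def check_pair(model_supports, guide_supports):
    """Return the names of random variables that violate the model-guide correspondence."""
    problems = []
    for name in sorted(set(model_supports) | set(guide_supports)):
        if name not in model_supports or name not in guide_supports:
            problems.append(name)  # variable used by only one of the two programs
        elif not support_included(guide_supports[name], model_supports[name]):
            problems.append(name)  # guide puts mass where the model has none
    return problems

# Hypothetical supports for the random variable sigma.
model = {"sigma": (0.0, 10.0)}
guide_bad = {"sigma": (-math.inf, math.inf)}
guide_ok = {"sigma": (0.0, 10.0)}

print(check_pair(model, guide_bad))  # ['sigma']
print(check_pair(model, guide_ok))   # []
```

As the next paragraph of the text explains, passing such a check is necessary but not sufficient for the KL divergence to be finite.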
Figures 2(b) and 2(c) show two attempts to resolve the undefined-KL issue. To fix the issue, we change the distribution of sigma in the guide in (b), and in the model in (c). These revisions remove the problem about the support of sigma, but do not eliminate that of the undefined KL. In both (b) and (c), the KL divergence is $\infty$. This happens mainly because sigma can be arbitrarily close to 0 in the guide in both cases, which makes the integrand in the definition of the KL divergence diverge to $\infty$.
An SVI-specific verification challenge related to this example is how to prove the well-definedness of the KL divergence and more generally the optimisation objective of an SVI algorithm. In §6.2, we provide a partial answer to the question. We give a condition for ensuring the well-definedness of the KL divergence. Our condition is more automation-friendly than the definition of KL, because it does not impose the difficult-to-check integrability requirement present in the definition of KL.
The second example appears in Figure 3(a). It uses the same model as in Figure 1(a), but has a new guide that uses a uniform distribution parameterised by $\theta$. For this model-guide pair, the KL divergence is well-defined for all $\theta$, and there is an optimal $\theta$ minimising the KL.
However, as shown in Figure 3(b), the gradient of the KL divergence is undefined at two values of $\theta$, because the KL divergence is not differentiable there. For all the other $\theta$, the KL divergence and its gradient are both defined, but the score estimator cannot estimate this gradient in an unbiased manner (i.e., in a way satisfying (1)), thereby losing the convergence guarantee to a local optimum. The precise calculation is not appropriate in this section, but we just point out that the expectation of the estimated gradient is zero for all such $\theta$, whereas the true gradient of the KL is non-zero for those $\theta$; the true gradient has a closed form given in terms of normal densities. The mismatch comes from the invalidity of one implicit assumption about interchanging integration and gradient in the justification of the score estimator; see §5 for detail.
To sum up, the second example shows that even if the KL divergence is defined, its gradient is sometimes undefined, and also that even if both the KL divergence and its gradient are defined, the sample-based estimate of the gradient in a standard SVI algorithm may be biased; this means that an equation similar to (1) does not hold and an SVI algorithm is no longer guaranteed to converge to a local optimum. Proving that these failure cases do not arise is another SVI-specific verification challenge. In §6.3, we give another example of similar flavour, and provide an automation-friendly condition that ensures the existence of the KL divergence and its gradient as well as the unbiasedness of the gradient estimate of the score estimator.
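The bias can be reproduced numerically with a small sketch. It assumes a guide of the form Uniform($\theta - 1$, $\theta + 1$) and the same illustrative model constants as a typical instance of the first example; both are assumptions, as are all function names. Inside its support the uniform density is the constant 1/2, so the score $\nabla_\theta \log q_\theta(v)$ is zero and the estimator returns zero, while differentiating the objective directly gives a non-zero value.

```python
import math
import random

def log_normal_pdf(x, mean, std):
    """Log density of Normal(mean, std) at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def log_joint(v, observed=0.0):
    """log p(v, o) for a hypothetical model: Normal(0, 5) prior, sign-dependent likelihood."""
    lik_mean = 1.0 if v > 0 else -2.0
    return log_normal_pdf(v, 0.0, 5.0) + log_normal_pdf(observed, lik_mean, 1.0)

def grad_log_guide(v, theta):
    """For v inside (theta - 1, theta + 1), log q_theta(v) = log(1/2), so the score is 0."""
    return 0.0

def score_gradient_estimate(theta, num_samples, rng):
    """Score estimator applied to the uniform guide: every summand is 0."""
    total = 0.0
    for _ in range(num_samples):
        v = rng.uniform(theta - 1.0, theta + 1.0)
        total += grad_log_guide(v, theta) * (math.log(0.5) - log_joint(v))
    return total / num_samples

def true_kl_gradient(theta):
    """Derivative of log(1/2) - (1/2) * integral of log p(v, o) over [theta - 1, theta + 1].

    By the Leibniz rule this is -(1/2) * (log p(theta + 1, o) - log p(theta - 1, o)),
    valid when neither endpoint is 0, where log p(., o) jumps.
    """
    return -0.5 * (log_joint(theta + 1.0) - log_joint(theta - 1.0))

estimate = score_gradient_estimate(3.0, 100, random.Random(0))
print(estimate, true_kl_gradient(3.0))
```

Under these assumptions the estimator yields 0 at $\theta = 3$ while the direct derivative is $\theta / 25 = 0.12$, so gradient descent driven by the estimator would never move $\theta$.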
3. Review of Measure Theory and Notations
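A $\sigma$-algebra $\Sigma$ on a set $X$ is a collection of subsets of $X$ such that (i) $X \in \Sigma$; (ii) $A_0 \cup A_1 \in \Sigma$ and $X \setminus A_0 \in \Sigma$ for all $A_0, A_1 \in \Sigma$; (iii) $\bigcup_n A_n \in \Sigma$ when all subsets $A_n$ are in $\Sigma$. An equivalent but easier-to-remember characterisation is that $\Sigma$ is closed under boolean operations and countable union and intersection. We call the pair $(X, \Sigma)$ of a set and a $\sigma$-algebra a measurable space, and subsets in $\Sigma$ measurable. A function $f$ from a measurable space $(X, \Sigma)$ to another measurable space $(X', \Sigma')$ is measurable if $f^{-1}(B) \in \Sigma$ for all $B \in \Sigma'$.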
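On a finite set the three conditions can be checked mechanically, which may help to build intuition. The sketch below is illustrative only and the function name is hypothetical. On a finite universe, closure under countable unions reduces to closure under pairwise unions.

```python
from itertools import combinations

def is_sigma_algebra(universe, collection):
    """Check the sigma-algebra conditions for subsets of a finite universe."""
    universe = frozenset(universe)
    sets = {frozenset(s) for s in collection}
    if universe not in sets:                 # condition (i): the whole set is included
        return False
    for a in sets:                           # condition (ii): closed under complement
        if universe - a not in sets:
            return False
    for a, b in combinations(sets, 2):       # conditions (ii)/(iii): closed under union
        if a | b not in sets:
            return False
    return True

print(is_sigma_algebra({1, 2, 3, 4}, [set(), {1, 2}, {3, 4}, {1, 2, 3, 4}]))  # True
print(is_sigma_algebra({1, 2, 3, 4}, [set(), {1}, {1, 2, 3, 4}]))             # False
```

The second collection fails because the complement of {1} is missing.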
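An example of a measurable space is the $n$-dimensional Euclidean space $\mathbb{R}^n$ with the Borel $\sigma$-algebra $\mathcal{B} \triangleq \sigma(\{(-\infty, r_1) \times \cdots \times (-\infty, r_n) \mid r \in \mathbb{R}^n\})$, where $\sigma$ is the closure operator that converts a collection of subsets of $\mathbb{R}^n$ into the smallest $\sigma$-algebra containing the collection. Subsets $X$ of $\mathbb{R}^n$ form measurable spaces with the induced $\sigma$-algebra $\{X \cap A \mid A \in \mathcal{B}\}$. Another example is a set $X$ with the so-called discrete $\sigma$-algebra on $X$ that consists of all subsets of $X$.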
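A measure $\mu$ on a measurable space $(X, \Sigma)$ is a function from $\Sigma$ to $[0, \infty]$ such that $\mu(\emptyset) = 0$ and $\mu$ satisfies the countable additivity condition: for a countable family of disjoint measurable subsets $B_n$,
\[ \mu\Big(\bigcup_{n} B_n\Big) = \sum_{n} \mu(B_n). \]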
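A well-known example is the Lebesgue measure $\lambda^n$ on $\mathbb{R}^n$ which maps each measurable subset of $\mathbb{R}^n$ to its volume in the usual sense. (Footnote 2: The Lebesgue measure $\lambda^n$ is the unique measure on $\mathbb{R}^n$ that sets the volume of the unit cube $(0,1)^n$ to $1$ and is translation invariant: for all measurable subsets $A$ and $r \in \mathbb{R}^n$, $\lambda^n(A) = \lambda^n(\{r' - r \mid r' \in A\})$.) When $\mu(X) \le 1$, we call $\mu$ a subprobability measure. If $\mu(X) = 1$, we may drop “sub”, and call $\mu$ a probability measure.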
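The Lebesgue integral $\int$ is a partial operator that maps a measure $\mu$ on $(X, \Sigma)$ and a real-valued measurable function $f$ on the same space to a real number. It is denoted by $\int \mu(dx)\, f(x)$. To follow the paper, it is enough to know that this integral generalises the usual Riemann integral from calculus. (Footnote 3: Another useful fact is that when $f$ is non-negative, $\int \mu(dx)\, f(x) = \sup \sum_i (\inf_{x \in A_i} f(x)) \cdot \mu(A_i)$, where the supremum is taken with respect to all finite partitions $\{A_i\}_i$ of $X$ into measurable subsets.) For a measure $\nu$ on $(X, \Sigma)$, if $\nu(A) = \int \mu(dx)\, (f(x) \cdot \mathbb{1}_{[x \in A]})$ for non-negative $f$, we say that $f$ is the density of $\nu$ with respect to $\mu$ and call $\mu$ the reference measure.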
In the paper, we use a few well-known methods for building measurable spaces.
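The first method applies when we are given a set $X$ and a collection of functions $\{f_i : X \to X_i \mid i \in I\}$ to measurable spaces $(X_i, \Sigma_i)$. The method is to equip $X$ with the smallest $\sigma$-algebra $\Sigma$ making all $f_i$’s measurable:
\[ \Sigma \triangleq \sigma\big(\{f_i^{-1}(B) \mid i \in I,\ B \in \Sigma_i\}\big). \]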
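The second relies on two constructions, product and disjoint union. Suppose that we are given measurable spaces $(X_i, \Sigma_i)$ for all $i \in I$. We define a product measurable space that has $\prod_{i \in I} X_i$ as its underlying set and the following product $\sigma$-algebra $\bigotimes_{i \in I} \Sigma_i$ as its $\sigma$-algebra:
\[ \bigotimes_{i \in I} \Sigma_i \triangleq \sigma\Big(\Big\{ \prod_{i \in I} A_i \;\Big|\; \text{there is a finite } I_0 \subseteq I \text{ such that } A_j = X_j \text{ for all } j \in I \setminus I_0 \text{ and } A_i \in \Sigma_i \text{ for all } i \in I_0 \Big\}\Big). \]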
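The construction of the product $\sigma$-algebra can be viewed as a special case of the first where we consider the smallest $\sigma$-algebra on $\prod_{i \in I} X_i$ that makes every projection map to $X_i$ measurable. When the $X_i$ are disjoint, they can be combined as a disjoint union. The underlying set in this case is $\bigcup_{i \in I} X_i$, and the $\sigma$-algebra is
\[ \bigoplus_{i \in I} \Sigma_i \triangleq \Big\{ A \;\Big|\; A \cap X_i \in \Sigma_i \text{ for all } i \in I \Big\}. \]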
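When $I = \{i, j\}$ with $i \neq j$, we denote the product measurable space by $(X_i \times X_j, \Sigma_i \otimes \Sigma_j)$. In addition, if $X_i$ and $X_j$ are disjoint, we write $(X_i \cup X_j, \Sigma_i \oplus \Sigma_j)$ for the disjoint-union measurable space.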
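The third method builds a measurable space out of measures or a certain type of measures, such as subprobability measures. For a measurable space $(X, \Sigma)$, we form a measurable space with measures. The underlying set of the space consists of all measures on $(X, \Sigma)$, and its $\sigma$-algebra is the one generated by the sets $\{\mu \mid \mu(A) \le r\}$ for $A \in \Sigma$ and $r \in \mathbb{R}$.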
The difficult part to grasp is this $\sigma$-algebra. Once again, a good approach for understanding it is to realise that it is the smallest $\sigma$-algebra that makes the function $\mu \mapsto \mu(A)$ from the set of measures to $[0, \infty]$ measurable for all measurable subsets $A \in \Sigma$. This measurable space gives rise to a variety of measurable spaces, each having a subset $M$ of the set of measures as its underlying set and the $\sigma$-algebra induced on $M$, which consists of the intersections of $M$ with the measurable sets of measures. In the paper, we use two such spaces, one induced by the set of subprobability measures on $X$ and the other by the set of probability measures.
A measurable function $\kappa$ from $(X, \Sigma)$ to the measurable space of measures on $(Y, \Sigma')$ is a kernel. If $\kappa(x)$ is a subprobability measure (i.e., $\kappa(x)(Y) \le 1$) for all $x \in X$, we say that $\kappa$ is a subprobability kernel. In addition, if $\kappa(x)(Y) = 1$ (i.e., $\kappa(x)$ is a probability measure) for every $x \in X$, we call $\kappa$ a probability kernel. A good heuristic is to view a probability kernel as a random function and a subprobability kernel as a random partial function. We use well-known facts that a function is a subprobability kernel if and only if it is a measurable map from $X$ to the space of subprobability measures on $Y$, and that similarly a function is a probability kernel if and only if it is a measurable function from $X$ to the space of probability measures on $Y$.
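We use a few popular operators for constructing measures throughout the paper. We say that a measure $\mu$ on a measurable space $(X, \Sigma)$ is finite if $\mu(X) < \infty$, and $\sigma$-finite if there is a countable partition of $X$ into measurable subsets $X_n$’s such that $\mu(X_n) < \infty$ for every $n$. Given a finite or countable family of $\sigma$-finite measures $\{\mu_i\}_{i \in I}$ on measurable spaces $(X_i, \Sigma_i)$’s, the product measure of $\mu_i$’s, denoted $\bigotimes_{i \in I} \mu_i$, is the unique measure on $(\prod_{i \in I} X_i, \bigotimes_{i \in I} \Sigma_i)$ such that for all measurable subsets $A_i$ of $X_i$,
\[ \Big(\bigotimes_{i \in I} \mu_i\Big)\Big(\prod_{i \in I} A_i\Big) = \prod_{i \in I} \mu_i(A_i). \]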
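Given a finite or countable family of measures $\{\mu_i\}_{i \in I}$ on disjoint measurable spaces $(X_i, \Sigma_i)$’s, the sum measure of $\mu_i$’s, denoted $\bigoplus_{i \in I} \mu_i$, is the unique measure on $(\bigcup_{i \in I} X_i, \bigoplus_{i \in I} \Sigma_i)$ such that
\[ \Big(\bigoplus_{i \in I} \mu_i\Big)\Big(\bigcup_{i \in I} A_i\Big) = \sum_{i \in I} \mu_i(A_i) \quad \text{for all } A_i \in \Sigma_i. \]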
Throughout the paper, we take the convention that the set $\mathbb{N}$ of natural numbers includes $0$. For all positive integers $n$, we write $[n]$ to mean the set $\{1, 2, \ldots, n\}$.
4. Simple Probabilistic Programming Language
In this section, we describe the syntax and semantics of a simple probabilistic programming language, which we use to present the theoretical results of the paper.
We use an extension of the standard while language with primitives for probabilistic programming. The grammar of the language is given in Figure 4. Variables in the language store real numbers, but expressions may denote reals, booleans and strings, and they are classified into real, boolean and string expressions based on these denoted values. The primitive functions for reals and for strings may be the usual arithmetic and string operations, such as multiplication and exponentiation for the former and string concatenation for the latter.
The grammar for commands includes the cases for the standard constructs of the while language, such as assignment, sequencing, conditional statement and loops. In addition, it has two constructs for probabilistic programming. The first draws a sample from the normal distribution with a given mean and standard deviation, and names the sample with a given string. The second, the score statement, expresses that a sample is drawn from the normal distribution with a given mean and standard deviation and the value of this sample is observed to be a given value. It lets the programmers express information about observed data inside programs. Operationally, this construct can be understood as an instruction for updating a global variable that stores the so-called importance score of the execution. The score quantitatively records how well the random choices in the current execution match the observations, and the score statement updates this score by multiplying it with the density of the appropriate normal distribution at the observed value.
Consider the following program:
The program specifies a model with one random variable. Using a relatively flat normal distribution, the program specifies a prior belief about the likely value of the random variable and the range in which it lies. The next score statement refines this belief with one data point, which is a noisy observation of the value of the random variable (bound to a program variable). The parameters to the normal density in the statement express that the noise is relatively small. Getting the refined belief, called the posterior distribution, is the reason that a data scientist writes a model like this program. It is done by an inference algorithm of the language.
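The operational reading of such a program, in which an execution samples the random variable and multiplies a global importance score by a normal density, can be sketched as follows. The constants (a Normal(0, 5) prior, a data point 3.0 with noise standard deviation 1.0) and the function names are assumptions chosen for illustration. The posterior mean is then approximated by self-normalised importance sampling.

```python
import math
import random

def normal_pdf(x, mean, std):
    """Density of Normal(mean, std) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))

def run_program(rng):
    """One execution: sample from a flat prior, then score a single observation."""
    score = 1.0                            # importance score of this execution
    x = rng.gauss(0.0, 5.0)                # sample statement (hypothetical prior)
    score *= normal_pdf(3.0, x, 1.0)       # score statement (hypothetical data point)
    return x, score

def posterior_mean_estimate(num_runs, seed=0):
    """Weight each sampled value by its score and normalise by the total score."""
    rng = random.Random(seed)
    runs = [run_program(rng) for _ in range(num_runs)]
    total = sum(score for _, score in runs)
    return sum(x * score for x, score in runs) / total
```

For these constants the exact posterior mean is $3 \cdot 25 / 26 \approx 2.88$, so `posterior_mean_estimate(20000)` should return a value close to that.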
Permitting only the normal distribution does not limit the type of models expressible in the language. Every distribution can be obtained by transforming the standard normal distribution. (Footnote 4: Here we consider only Borel spaces.) Using only the normal distribution has an impact on stochastic variational inference to be discussed later, because it requires a guide to use only normal distributions. But the impact is minor, because most well-known approaches for creating guides from the machine-learning literature (such as extensions of the variational autoencoder) use normal distributions only, or can be made to do so easily.
The denotational semantics of the language just presented is mostly standard, but employs some twists to address the features for probabilistic programming (Staton et al., 2016).
Here is a short high-level overview of the semantics. Our semantics defines multiple measurable spaces, such as and , that hold mathematical counterparts to the usual actors in computation, such as program stores (i.e., mappings from variables to values) and states (which consist of a store and further components). Then, the semantics interprets expressions and commands as measurable functions of the following types:
Here the reals form a measurable space with the Borel $\sigma$-algebra, and the booleans and the strings form discrete measurable spaces. Stores (i.e., maps from variables to values) and states, which consist of a store and a part for recording information about sampled random variables, form measurable spaces as well. Note that the target measurable space of commands is built by first taking the product of the measurable space of states and that of scores, and then forming a space out of subprobability measures on this product. This construction indicates that commands denote probabilistic computations, and the result of each such computation consists of an output state and a score which expresses how well the computation matches observations expressed with the score statements in the command. Some of the possible outcomes of the computation may lead to non-termination or an error, and these abnormal outcomes are not accounted for by the semantics, which is why the denotation of a command at a given state is a subprobability distribution. The semantics of expressions is much simpler. It just says that expressions do not involve any probabilistic computations, so that they denote deterministic measurable functions.
We now explain how this high-level idea gets implemented in our semantics. Let be a countably infinite set of variables. Our semantics uses the following sets:
A state consists of a store and a random database . The former fixes the values of variables, and the latter records the name (given as a string) and the value of each sampled random variable. The domain of is the names of all the sampled random variables. By insisting that should be a map, the semantics asserts that no two random variables have the same name. For each state , we write and for its store and random database components, respectively. Also, for a variable , a string and a value , we write and to mean and .
We equip all of these sets with $\sigma$-algebras and turn them into measurable spaces in a standard way. Note that we constructed the sets from the reals by repeatedly applying the product and disjoint-union operators. We equip the reals with the usual Borel $\sigma$-algebra. Then, we parallel each usage of the product and the disjoint-union operators on sets with that of the corresponding operators on $\sigma$-algebras. This gives the $\sigma$-algebras for all the sets defined above. Although absent from the above definition, the measurable spaces of booleans and strings, equipped with discrete $\sigma$-algebras, are also used in our semantics.
We interpret expressions as measurable functions , , and , under the assumption that the semantics of primitive real-valued of arity and string-valued of arity are given by measurable functions and . It is standard, and we describe it only for some sample cases of and :
Lemma 4.1 ().
For all expressions , , and , their semantics , and are measurable functions from to , and , respectively.
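We interpret commands as measurable functions from states to subprobability measures on pairs of a state and a score, i.e., as subprobability kernels. Consider the set of all such subprobability kernels, and equip it with the following pointwise partial order: one kernel is below another if and only if, for every input state and every measurable set of state-score pairs, the former assigns to the set a measure no greater than the latter does. The following lemma is a minor adaptation of a known result.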
Lemma 4.2 ().
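With this order, the set of subprobability kernels is an $\omega$-complete partial order whose least element is the kernel that maps every state to the zero measure.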
The semantics of a command is defined in Figure 5.
The interpretation of the loop is the least fixed point of a function $F$ on the $\omega$-complete partial order of subprobability kernels. The function $F$ is continuous, so that the least fixed point is obtained as the limit of the sequence $\{F^n(\bot)\}_{n}$, where $\bot$ is the least element. In the definition of $F$, the argument is a subprobability kernel, and it represents the computation after the first iteration of the loop. The semantics of the sample statement uses an indicator function to exclude erroneous executions where the argument denotes a name already used by some previous random variable, or the standard deviation is not positive. When this check passes, it distills the given measurable set of outcomes to a property on the value of the sampled variable and computes the probability of the property using the normal distribution with the given mean and standard deviation.
Theorem 4.3 ().
For every command , its interpretation is well-defined and belongs to .
4.3. Posterior inference and density semantics
We write a probabilistic program to answer queries about the model and data that it describes. Among such queries, posterior inference is one of the most important and popular. Let be the initial state that consists of some fixed store and the empty random database. In our setting, posterior inference amounts to finding information about the following probability measure for a command . For a measurable ,
The probability measure is called the posterior distribution of , and the unnormalised posterior distribution of (in and , we elide the dependency on to avoid clutter). Finding information about the former is the goal of most inference engines of existing probabilistic programming languages. Of course, is not defined when the normalising constant is infinite or zero. The inference engines regard such a case as an error that a programmer should avoid, and consider only without such an error.
Most algorithms for posterior inference use the density semantics of commands. They implicitly pick measures on some measurable spaces used in the semantics. These measures are called reference measures, and constructed out of Lebesgue and counting measures (Bhat et al., 2012, 2013; Hur et al., 2015). Then, the algorithms interpret commands as density functions with respect to these measures. One outcome of this density semantics is that the unnormalised posterior distribution of a command has a measurable function such that , where is a reference measure on . Function is called density of with respect to .
In the rest of this subsection, we reformulate the semantics of commands using density functions. To do this, we need to set up some preliminary definitions.
First, we look at a predicate and an operator for random databases, which are about the possibility and the very act of merging two databases. For , define the predicate by:
When , let be the random database obtained by merging and :
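Viewing a random database as a finite map from names to real values, the predicate and the merge operator can be sketched with Python dictionaries. The reading of the predicate as "the two databases record disjoint sets of names" is an assumption made for this illustration, as are the function names.

```python
def compatible(r1, r2):
    """Predicate: the two random databases record disjoint sets of names."""
    return r1.keys().isdisjoint(r2.keys())

def merge(r1, r2):
    """Merge two compatible random databases into a single one."""
    if not compatible(r1, r2):
        raise ValueError("random databases share a name and cannot be merged")
    merged = dict(r1)      # copy so that neither input is modified
    merged.update(r2)
    return merged

print(merge({"a": 0.5}, {"b": -1.2}))        # {'a': 0.5, 'b': -1.2}
print(compatible({"a": 0.5}, {"a": 2.0}))    # False
```

Requiring disjointness before merging keeps the result a map, in line with the requirement that no two random variables have the same name.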
Lemma 4.4 ().
For every measurable , the function from to is measurable.
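Second, we define a reference measure $\rho$ on the measurable space of random databases: for every measurable set $R$ of random databases,
\[ \rho(R) \triangleq \sum_{K} \Big(\bigotimes_{\alpha \in K} \lambda\Big)\big(R \cap [K \to \mathbb{R}]\big), \]
where $K$ ranges over finite sets of names, $[K \to \mathbb{R}]$ is the set of random databases with domain $K$, and $\lambda$ is the Lebesgue measure on $\mathbb{R}$. As explained in the preliminary section, the symbol $\otimes$ here represents the operator for constructing a product measure. In particular, $\bigotimes_{\alpha \in K} \lambda$ refers to the product of the $|K|$ copies of the Lebesgue measure on $\mathbb{R}$. In the above definition, we view functions in $[K \to \mathbb{R}]$ as tuples with $|K|$ real components and measure sets of such functions using the product measure $\bigotimes_{\alpha \in K} \lambda$. When $K$ is the empty set, $\bigotimes_{\alpha \in K} \lambda$ is the nullary-product measure on the singleton set containing the empty map, which assigns $1$ to that singleton set and $0$ to the empty set.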
The measure $\rho$ computes the size of each measurable subset $R$ in three steps. It splits a given $R$ into groups based on the domains of elements in $R$. Then, it computes the size of each group separately, using the product of the Lebesgue measure. Finally, it adds the computed sizes. The measure $\rho$ is not finite, but it satisfies the $\sigma$-finiteness condition, the next best property. (Footnote 5: The condition lets us use the Fubini theorem when showing the well-formedness of the density semantics in this subsection and relating this semantics with the measure semantics in the previous subsection.)
Third, we define a partially-ordered set with certain measurable functions. We say that a function uses random databases locally or is local if for all , , and ,
The condition describes the way that uses a given random database , which plays the role of a bank of random seeds (that may partially consume as it needs random values). Some part of may be consumed by , but the unconsumed part of does not change and is returned in the output. Also, the behaviour of does not depend on the unconsumed . We define the set by
Here we view and as measurable spaces equipped with discrete and Borel -algebras. Also, we regard and as measurable spaces constructed by the product and disjoint-union operators on measurable spaces, as explained in §3.
The locality in the definition of formalises expected behaviours of commands. In fact, as we will show shortly, it is satisfied by all commands in our density semantics. This property plays an important role when we establish the connection between the density semantics in this subsection and the standard measure semantics in §4.2.
The functions in are ordered pointwise: for all ,
Lemma 4.5 ().
With this order, the set is an $\omega$-complete partial order and has a least element. Thus, every continuous function on it has a least fixed point (and this least fixed point is unique).
For each , let be the following lifting to a function on :
This lifting lets us compose two functions in .
Using these preliminary definitions, we define a density semantics in Figure 6, where a command means a local measurable function on pairs of a store and a random database. The removal notation in the figure means the removal of the entry for a given name from a finite map if the name is in the domain of the map; otherwise, the map is left unchanged. The set membership in the figure says that the denotation of a command is a local measurable function. Thus, the function takes a store and a random database as inputs, where the former fixes the values of variables at the start of the command and the latter specifies random seeds, some of which the command may consume to sample random variables. Given such inputs, the function outputs an updated store, the part of the random database not consumed by the command, the total score expressing how well the execution of the command matches observations, and the probability density of the command at the consumed part of the random database. If the random database does not contain enough random seeds, or the execution of the command encounters some runtime error, or it falls into an infinite loop, then the function returns $\bot$.
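The input-output behaviour just described can be sketched in Python for the two probabilistic constructs and sequencing. This is an illustration of the shape of the density semantics only, not the definition in Figure 6. The names sample_norm, score_norm and seq, the use of None for the undefined result, and the restriction of parameters to constants or a variable name are all assumptions.

```python
import math

def normal_pdf(x, mean, std):
    """Density of Normal(mean, std) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))

def sample_norm(var, name, mean, std):
    """Consume the seed stored under `name` and bind it to the variable `var`."""
    def g(store, rdb):
        if std <= 0 or name not in rdb:
            return None                                # runtime error or missing seed
        value = rdb[name]
        new_store = dict(store)
        new_store[var] = value
        rest = {k: v for k, v in rdb.items() if k != name}   # unconsumed part
        return new_store, rest, 1.0, normal_pdf(value, mean, std)
    return g

def score_norm(observed, mean_var, std):
    """Multiply the score by the normal density at `observed`; consume no seed."""
    def g(store, rdb):
        if std <= 0:
            return None
        return dict(store), dict(rdb), normal_pdf(observed, store[mean_var], std), 1.0
    return g

def seq(g1, g2):
    """Run g1, then g2 on its outputs; scores and densities multiply."""
    def g(store, rdb):
        first = g1(store, rdb)
        if first is None:
            return None
        store1, rdb1, score1, dens1 = first
        second = g2(store1, rdb1)
        if second is None:
            return None
        store2, rdb2, score2, dens2 = second
        return store2, rdb2, score1 * score2, dens1 * dens2
    return g

program = seq(sample_norm("x", "a", 0.0, 5.0), score_norm(3.0, "x", 1.0))
print(program({}, {"a": 2.0, "b": 7.0}))
print(program({}, {}))   # None: not enough random seeds
```

In the first call the entry for "b" is returned untouched, which mirrors the locality condition: the unconsumed part of the random database neither changes nor influences the result.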