The decoupled extended Kalman filter for dynamic exponential-family factorization models

Carlos A. Gómez-Uribe
cgomez@alum.mit.edu
   Brian Karrer
briankarrer@fb.com
Facebook
Abstract

We specialize the decoupled extended Kalman filter (DEKF) for online parameter learning in factorization models, including factorization machines, matrix and tensor factorization, and illustrate the effectiveness of the approach through simulations. Learning model parameters through the DEKF makes factorization models more broadly useful by allowing for more flexible observations through the entire exponential family, modeling parameter drift, and producing parameter uncertainty estimates that can enable explore/exploit and other applications. We use a more general dynamics of the parameters than the standard DEKF, allowing parameter drift while encouraging reasonable values. We also present an alternate derivation of the regular extended Kalman filter and DEKF that connects these methods to natural gradient methods, and suggests a similarly decoupled version of the iterated extended Kalman filter.

1 Introduction

Regression, matrix and tensor factorization, factorization machines, and other statistical models can be viewed as variations of a general model with exponential family observations. This view generalizes factorization models to a broader class of observation distributions than has been considered.

We show that an approximate Gaussian posterior of the parameters for this general class of models can be learned online, even when the parameters drift over time, through the extended Kalman filter (EKF). Modeling parameter drift can be desirable in situations where the underlying data is non-stationary, as is often the case in recommender systems, e.g., where user preferences change over time.

Maintaining a full covariance matrix of the parameters, as prescribed by the EKF, can be prohibitive in terms of memory and computation. The DEKF, introduced in [22] to train neural networks, alleviates the memory and computational requirements of the EKF by maintaining a block diagonal approximation of the covariance matrix. Because the EKF is also related to online Fisher scoring and the online natural gradient algorithm, so is the DEKF, as discussed later.

The block-diagonal covariance approximation of the DEKF is particularly relevant for models, such as factorization models, with a large number of parameters but where only a relatively small subset of them is involved in every observation. Specifically, we assume that the model parameters can be naturally grouped into subsets we call entities (these subsets are called nodes in the original DEKF paper, but we find entity more descriptive for factorization models), such that only a relatively small number of entities are involved in each observation. E.g., in matrix factorization exactly two subsets of parameters define each observation, those for the user and the item interacting, so we can let each user and each item correspond to an entity.

We show that the DEKF only requires updating the parameters of entities involved in an observation, leading to a particularly efficient implementation of the DEKF for factorization models. Because the DEKF produces a posterior distribution of the parameters, it also enables applications that require uncertainty estimates, e.g., where explore/exploit trade-offs are important. To the best of our knowledge, however, the DEKF has not been applied to factorization models before.

The DEKF we present here is different from the standard DEKF in several ways. First, we specialize it to exponential family models, motivated by models with typically few entities per observation. Second, the standard DEKF was formulated for static parameters, or for parameters that undergo a simple random walk. The latter choice often results in parameter values that become too large and lead to badly behaved models. Here, we consider a more general dynamics of the parameters that allows for parameter drift while encouraging parameters to take on reasonable values. The particular dynamics we choose allows for lazy posterior updates, and requires augmenting the space of parameters. To keep our paper self-contained, we assume no familiarity with the DEKF or even Kalman filtering in general.

Our paper is organized as follows. Section 2 introduces the general model we study, and shows several factorization models are special cases of this general model. Section 3 describes our DEKF, and specializes it to regression, matrix and tensor factorization, and factorization machines. Section 4 discusses connections of the EKF and DEKF to other methods, including an iterated version of the EKF. We also show that one can similarly motivate and obtain a decoupled iterated EKF. Section 5 describes numerical simulation results obtained from the application of our DEKF to a variety of models and tasks, including explore/exploit. Section 6 concludes with a discussion about limitations, and suggests possible research directions.

2 Exponential Family Observation Model

The observation model we consider is a generalization of the Generalized Linear Model (GLM), e.g., see [7, 18], where we let the mapping from the model parameters into the so-called signal be different for different kinds of models.

Consider a model with parameters $w \in \mathbb{R}^p$, for a positive integer $p$, that produces an observation $y \in \mathbb{R}^d$, where $d$ is another positive integer. Typically, $d$ is much smaller than $p$. We assume that $y$ has a probability distribution in the natural exponential family, with log-likelihood

$$\log p(y \mid \eta) = y^\top \Phi^{-1} \eta - a(\eta) + f(y, \Phi). \qquad (1)$$

Here $\eta \in \mathbb{R}^d$ is the natural parameter of the distribution, $\Phi$ is a positive definite matrix that is a nuisance parameter, and the symbol $\top$ denotes the matrix transpose operation. The nuisance parameter, and the functions $a(\eta)$ and $f(y, \Phi)$, follow from the specific member of the exponential family, and are all assumed known. Generally the nuisance parameter is the identity matrix, though in linear regression with known covariance, $\Phi$ is the covariance of the observations. (The exponential family generally uses a vector of sufficient statistics $T(y)$ instead of $y$, where $y$ is an underlying vector of observations. To avoid additional notation, we consider our observation vector to just be the vector of sufficient statistics.)

We denote the mean and covariance of $y$ given $\eta$ by $\mu(\eta)$ and $S(\eta)$, though we may omit the dependence on $\eta$ for improved readability. For distributions in the form of Equation 1, it can be shown that

$$\mu(\eta) = \rho(\eta) = \Phi \, \frac{\partial a(\eta)}{\partial \eta}, \qquad (2)$$
$$S(\eta) = \frac{\partial \rho(\eta)}{\partial \eta} \, \Phi, \qquad (3)$$

where $\rho(\eta)$ is called the response function.

To connect the observations to the model parameters, we assume that $\eta$ is a deterministic and possibly non-linear function of $w$, with finite second derivatives. Often, $\eta$ is also a function of context denoted by $x$, e.g., the covariates in the case of regression, or the indices corresponding to the user and item involved in an observation for matrix factorization. It is typical and helpful to think of an intermediate and simple function $s$ of $w$ and $x$ that the natural parameter is a function of, i.e., $\eta = g(s(w, x))$. This intermediate function is called the signal, and outputs values in $\mathbb{R}^d$. To avoid notation clutter, we suppress all dependencies on $x$.

We will often need to evaluate the mean and variance of $y$, as well as other functions, for specific values of $w$, e.g., for $w = v$ for some arbitrary vector $v$. Abusing notation for improved readability, we will write $\mu(v)$ and $S(v)$ instead of $\mu(\eta(v))$ and $S(\eta(v))$ to denote the mean and covariance of $y$ when $w = v$.

The model also needs an invertible function $g$ called the link function that maps the signal to $\eta$, so $\eta = g(s)$. Depending on the family, $\eta$ can have a restricted range of values (e.g., only a subset of the reals), and for ease of exposition, we only consider link functions that obey these ranges without restricting the signal.

A particularly useful choice for the link function is the canonical link function that makes $\eta = s$ and simplifies relevant mathematics. Because the specific distribution within the exponential family determines the response function $\rho$, different distributions have different canonical links.

To summarize, $w$ determines $\eta$, but only through the signal $s$. Then $\eta$ determines the mean and covariance of $y$ given $w$ via Equations 2 and 3.
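As a concrete instance of this chain, the following minimal sketch (our own example, not from the paper) evaluates the moments of a univariate Poisson observation under the canonical log link, where the natural parameter equals the signal and the response function is the exponential:

```python
import numpy as np

def poisson_moments(signal: float) -> tuple[float, float]:
    """Return (mean, variance) of y given the signal, canonical link."""
    eta = signal          # canonical link: g(s) = s, so eta = s
    mean = np.exp(eta)    # response function rho(eta) = exp(eta)
    var = np.exp(eta)     # for the Poisson, Var(y) = E[y]
    return mean, var

# With a signal of 0, the mean and variance are both exp(0) = 1.
mean, var = poisson_moments(0.0)
```

Swapping in another family only changes the response function and its derivative, e.g., the logistic function for Bernoulli observations.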

2.1 Model Examples

Different important model classes only differ in the mapping from $w$ to the signal $s$, as shown below for some examples.

  1. The GLM. It is obtained by setting $s = X w$, where $X \in \mathbb{R}^{d \times p}$ is a matrix of predictors. This model includes linear and non-linear regression models.

    The EKF has already been applied to the GLM with dynamic parameters, e.g., see [5]. But the DEKF can further enable learning for GLM models with many parameters and sparse $X$.

  2. Matrix factorization (MF). Consider a set of entities referred to as users and items, each described by a vector in $\mathbb{R}^k$ for some small $k$, and let $w$ consist of the stacking of all user and item vectors.

    Assume that each observation is univariate, and describes the outcome of the interaction between user number $i$ and item number $j$. Let the user and item selector matrices $U_i$ and $V_j$ in $\mathbb{R}^{k \times p}$ be such that they act on $w$ to pull out the user and item vectors $u_i = U_i w$ and $v_j = V_j w$. The signal in these models is quadratic in $w$, and is given by $s = u_i^\top v_j$. Sometimes bias terms for each user, item, or overall are added to the signal too, but we leave them out for simplicity of exposition.

    MF models typically assume that the observations are Gaussian, or occasionally Bernoulli, e.g., see [17, 12], so our setup generalizes these models to observations in other exponential family distributions that can be more natural modeling choices for different kinds of data. In addition, applying the DEKF to these models allows for user and item vector drift, and enables applications, such as explore/exploit, that need the uncertainty of parameter estimates.

  3. Tensor factorization (TF). The CANDECOMP / PARAFAC (CP) decomposition of an order-$m$ tensor [11] has entities for each of the $m$ dimensions, or modes, of the tensor. When $m = 2$, the model is equivalent to MF with two kinds of entities, users and items. Each entity in a TF model is described by a vector in $\mathbb{R}^k$ for a small $k$, and is associated with one of the modes, e.g., users for mode one and items for mode two when $m = 2$. Similarly, $w$ consists of stacking all these vectors together. Each observation is univariate, and describes the interaction between $m$ entities, one per mode. Denote the corresponding entity vectors involved in the observation by $v_1, \dots, v_m$. The signal is defined as $s = \sum_{f=1}^{k} \prod_{j=1}^{m} v_j[f]$, where $v_j[f]$ is the $f$-th entry of $v_j$. Note that when $m = 2$ the signal is the same as in MF models. Our setup offers similar advantages in TF models as in MF models: more flexible observations, parameter drift, and uncertainty estimates.

  4. Factorization machines (FM). These models, introduced in [23], typically have univariate responses, and include univariate regression, MF, and tensor models as special cases.

    Assume there are $n$ entities, e.g., users or items, that can be involved in any of the observations, and let $x_i$ be non-zero only when entity $i$ is involved in the observation, with $x \in \mathbb{R}^n$. Let $w_i$ be the parameters corresponding to entity $i$.

    In a factorization machine (FM) of order 2, each entity $i$ has parameters $b_i \in \mathbb{R}$ and $v_i \in \mathbb{R}^k$, with $k$ a positive integer, so $w_i = (b_i, v_i)$. Then the signal becomes

    $$s = b_0 + \sum_{i=1}^{n} x_i b_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} x_i x_j \, v_i^\top v_j. \qquad (4)$$

    When $x$ has exactly two non-zero entries set to 1, then the above becomes identical to MF, with a user, an item, and a general bias term. An FM model of order 1 drops the pairwise interaction terms for all entities, so the signal reduces to that of the GLM.

    More generally, in an FM of order $m$, each entity $i$ is described by vectors $v_i^{(2)}, \dots, v_i^{(m)}$, where $v_i^{(l)} \in \mathbb{R}^{k_l}$ for all $i$, and where we define $v_i^{(1)} = b_i$ (so $k_1 = 1$). Then the signal becomes

    $$s = b_0 + \sum_{l=1}^{m} \; \sum_{i_1 < i_2 < \dots < i_l} \left( \prod_{j=1}^{l} x_{i_j} \right) \sum_{f=1}^{k_l} \prod_{j=1}^{l} v^{(l)}_{i_j, f}. \qquad (5)$$

    Here $v^{(l)}_{i,f}$ is the $f$-th entry in $v^{(l)}_i$.

    The parameters of entity $i$ are then obtained by stacking the vectors into $w_i$, so $w_i = (v_i^{(1)}, \dots, v_i^{(m)})$, where $v_i^{(1)} = b_i$. Lastly, let entity $0$ be another entity, introduced just to simplify notation, that contains the general bias term $b_0$. Then all the model parameters can be collected into $w = (w_0, w_1, \dots, w_n)$.

    FMs are typically learned via stochastic gradient descent, Markov chain Monte Carlo, or alternating least squares or coordinate ascent [24]. Our treatment extends FMs beyond Bernoulli and Gaussian observations, allows for dynamic parameters, and provides parameter uncertainty estimates.

  5. Other Models. There are other important statistical models that are also described by the general setup in Section 2, have a large number of parameters that can be grouped into entities, and where each observation also depends only on a small number of entities. We expect the DEKF to apply to such models too. However, there are also important models where most parameters are involved in every observation, and for which the block diagonal approximation of the covariance that the DEKF makes is less valid. Such is generally the case for feed-forward neural networks with a dense architecture and non-sparse inputs, despite these models being the initial motivation behind the DEKF. Finally, many important statistical models are not described by our general model and are beyond the scope of this paper, e.g., models with latent variables like mixture and topic models.
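The signal computations for some of these model classes can be sketched as follows; all names, dimensions, and random values here are illustrative, not from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
k = 3  # latent dimension

# GLM: s = X w for a matrix of predictors X.
X = rng.normal(size=(1, 5))
w = rng.normal(size=5)
s_glm = X @ w

# Matrix factorization: s = u^T v, with the user and item vectors pulled
# out of the stacked parameter vector by selector matrices.
theta = rng.normal(size=4 * k)                        # 2 users, 2 items stacked
U = np.zeros((k, 4 * k)); U[:, 0:k] = np.eye(k)       # selects user 0
V = np.zeros((k, 4 * k)); V[:, 2 * k:3 * k] = np.eye(k)  # selects item 0
u, v = U @ theta, V @ theta
s_mf = u @ v

# Second-order FM: s = b0 + sum_i x_i b_i + sum_{i<j} x_i x_j v_i^T v_j.
x = np.array([1.0, 1.0, 0.0])  # entity activations; two entities involved
w0, w_lin = 0.1, rng.normal(size=3)
U_fm = rng.normal(size=(3, k))
pair = 0.0
for i in range(3):
    for j in range(i + 1, 3):
        pair += x[i] * x[j] * (U_fm[i] @ U_fm[j])
s_fm = w0 + x @ w_lin + pair
```

With exactly two unit entries in `x` and no linear terms, the FM signal collapses to the MF signal, matching the reduction described above.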

3 The Decoupled Extended Kalman Filter

The Kalman filter was initially introduced in [10] for state estimation in linear systems driven by Gaussian noise, and with observations that depend linearly on the state and on more Gaussian noise. Since then, many variants of the Kalman filter have been developed and applied to a wide variety of situations, including parameter learning. See [27, 8] for good overviews of Kalman filters; the latter focused on neural network applications.

The EKF is a variant of the Kalman filter for non-linear dynamics and non-linear observations. Like the standard Kalman filter, the EKF has two steps: an update step that incorporates a new observation into the parameter estimates, and a predict step for models with time-varying parameters. We describe the update step for the EKF first, and then show how this step simplifies in the DEKF. We then describe the predict step for our variant of the DEKF, which, unlike the EKF predict step, assumes that the parameters of different entities follow independent dynamics of a particular form.

3.1 The EKF Update Step

We assume that $w \sim \mathcal{N}(\mu, P)$, i.e., that the parameters have a Gaussian prior. The EKF computes an approximate Gaussian posterior $\mathcal{N}(\hat{\mu}, \Sigma)$ for the parameters.

First, define the auxiliary matrix function

$$A(w) = \Phi^{-1} \, \frac{\partial \eta(w)}{\partial w}. \qquad (6)$$

Here $\partial \eta / \partial w$ is the derivative of the natural parameter with respect to $w$. Given a value of $w$, $A(w)$ is a $d$-by-$p$ matrix.

The mean and covariance of the approximate Gaussian posterior are then found via:

$$\hat{\mu} = \mu + P A^\top \left[ S^{-1} + A P A^\top \right]^{-1} S^{-1} \big( y - \mu(\mu) \big), \qquad (7)$$
$$\Sigma = P - P A^\top \left[ S^{-1} + A P A^\top \right]^{-1} A P. \qquad (8)$$

Here $A$ denotes $A(w)$ evaluated at $w = \mu$, and we use that notation elsewhere too for some function evaluations. Note that the matrix in the square brackets above, whose inverse is needed, is only of size $d$-by-$d$. Also, we call $y - \mu(\mu)$ the (predictive) error evaluated at $\mu$, so we see that the update to the mean in Equation 7 is proportional to this error.

In models with univariate observations, the EKF update equations simplify considerably, e.g., the necessary inverse becomes a simple fraction, and $S$ becomes a scalar. Letting $\sigma^2(\mu)$ denote the variance of $y$ when $w = \mu$, the update equations become

$$\hat{\mu} = \mu + \frac{P A^\top \big( y - \mu(\mu) \big)}{1 + \sigma^2(\mu) \, A P A^\top}, \qquad (9)$$
$$\Sigma = P - \frac{\sigma^2(\mu) \, P A^\top A P}{1 + \sigma^2(\mu) \, A P A^\top}. \qquad (10)$$

Note that here $A \in \mathbb{R}^{1 \times p}$, i.e., it is a row vector.

Specializing these equations to a concrete model requires specifying the distribution of the observation, and the link function, to determine $\mu$, $S$, $\Phi$, and $\partial \eta / \partial s$. The latter is needed to compute $A$. The last quantity, $\partial s / \partial w$, comes from the specific model being used, e.g., regression, MF, etc.
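For concreteness, here is a minimal sketch of the univariate update, written (with names of our own choosing) in terms of the gradient of the predicted mean rather than the natural parameter; the two parametrizations are algebraically equivalent:

```python
import numpy as np

def ekf_update_univariate(mu, P, y, g, y_mean, sigma2):
    """One EKF update for a scalar observation y.

    mu: (p,) prior mean; P: (p, p) prior covariance;
    g: (p,) gradient of the predicted mean w.r.t. the parameters,
    evaluated at the prior mean; y_mean: predicted mean;
    sigma2: predicted observation variance.
    """
    Pg = P @ g
    denom = sigma2 + g @ Pg          # the required "inverse" is a scalar fraction
    mu_post = mu + Pg * (y - y_mean) / denom
    P_post = P - np.outer(Pg, Pg) / denom
    return mu_post, P_post

# Usage sketch: a Gaussian observation with identity link, so the
# predicted mean is x^T mu and g = x.
mu0, P0 = np.zeros(2), np.eye(2)
x = np.array([1.0, 0.0])
mu1, P1 = ekf_update_univariate(mu0, P0, y=1.0, g=x, y_mean=x @ mu0, sigma2=1.0)
```

Note how the update only shrinks variance along directions that the observation actually measures.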

3.2 The DEKF Update Step

The DEKF is particularly relevant for applications where each observation involves a relatively small number of entities. Over time, of course, we expect all entities to be involved in multiple observations.

Our first goal for this section is to show that only the parameters for entities involved in an observation need to be updated when that observation is received. Our second goal is to show how to update the remaining model parameters efficiently, which we accomplish by enforcing a block diagonal approximation of the covariance matrix with one block per entity.

Assume that the model has $n$ entities, each with parameters $w_i \in \mathbb{R}^{p_i}$, with $\sum_i p_i = p$, and let $w = (w_1, \dots, w_n)$. Each parameter is a part of exactly one entity. Without loss of generality, assume that the first $m$ entities are those involved in the current observation, i.e., that the signal is only a function of the parameters of these entities. Of course, then $\eta$ and $\mu$ are also only a function of these entities. Let $p_a$ and $p_b$ be the number of parameters involved in the observation, and in the complement set of parameters. Let $w_a$ be the set of parameters involved in the observation, and $w_b$ the rest of the parameters, so that $w = (w_a, w_b)$.

We then assume that the prior estimates of different entities are uncorrelated, i.e., we only model correlations within entities. The covariance of the prior is then block diagonal, with $P_i$ in its $i$-th diagonal block. Let $\mu_i$ be the prior mean of $w_i$. Because the prior is Gaussian, this implies that the parameters of different entities are mutually independent, i.e., that $p(w) = \prod_i p_i(w_i)$ with $p_i(w_i) = \mathcal{N}(\mu_i, P_i)$.

From Bayes theorem, we know that

$$p(w \mid y) \propto p(y \mid w) \, p(w), \qquad (11)$$

where $p(w)$ is the Gaussian prior of $w$, and where $p(w \mid y)$ is the posterior. Substituting the factorized form for the prior into Equation 11 yields the posterior for $w$

$$p(w \mid y) \propto p(y \mid w) \, p(w_a) \, p(w_b) = \big[ p(y \mid w_a) \, p(w_a) \big] \, p(w_b). \qquad (12)$$

Here the last equality relies on the likelihood not being a function of $w_b$. Equation 12 implies that $p(w \mid y) = p(w_a \mid y) \, p(w_b)$. This result has important consequences.

First, the posterior of the entities involved in the observation is independent from the posterior of the rest of the entities, and is proportional to $p(y \mid w_a) \, p(w_a)$. In addition, the posterior of the rest of the entities is identical to the prior, i.e., $p(w_b \mid y) = p(w_b)$. This last observation is crucial to make our algorithm efficient, since it means that we only need to evaluate the update Equations 7 and 8 for the set of parameters $w_a$. The estimates for the rest of the parameters remain unchanged from their prior estimates.

Another important consequence is that we can add new entities as they appear, which can be crucial for some online settings, e.g., in recommender systems where new users and items appear all the time. The parameters for entities that have not been involved in any observations can just be appended into the set of parameters when the entity is first observed.

Data: Observation $y$, context $x$, prior means $\mu_i$, prior covariances $P_i$
Result: $\mu_i$, $P_i$ updated in-place
1. Let $E$ be the entities involved in the observation.
2. For $i$ in $E$: evaluate the signal and the gradient $A_i$ at the prior means.
3. Compute the predicted mean, the predicted variance, the error, and the $d$-by-$d$ matrix to invert from the $A_i$ and $P_i$.
4. For $i$ in $E$: update $\mu_i$ and $P_i$ via Equations 7 and 8, restricted to the diagonal blocks.
5. Return $\mu_i$, $P_i$.
Algorithm 1 DEKF for models with static parameters.

Entities also help to speed up the evaluation of Equations 7 and 8. Note that these equations require the computation of $A$ and $P A^\top$ evaluated at the prior mean. Because only the first $m$ entities are involved in the observation, we have that

$$A = \begin{bmatrix} A_a & 0 \end{bmatrix}, \qquad (13)$$

where $0$ is a matrix with entries set to zero of dimensions $d$-by-$p_b$. Combined with the block diagonal structure of $P$, this yields

$$A P A^\top = A_a P_a A_a^\top, \qquad (14)$$
$$P A^\top = \begin{bmatrix} P_a A_a^\top \\ 0 \end{bmatrix}, \qquad (15)$$
$$A P = \begin{bmatrix} A_a P_a & 0 \end{bmatrix}, \qquad (16)$$

where $0$ is again defined to have the appropriate dimensions.

Evaluating the expressions above at the prior mean leaves little extra work to compute the updated parameters $\hat{\mu}_a$ and $\Sigma_a$. The resulting posterior covariance $\Sigma_a$, however, is not typically block diagonal.

Letting $\Sigma_{ij}$ denote the updated block for entities $i$ and $j$ in the observation, we have that

$$\Sigma_{ij} = \mathbb{1}[i = j] \, P_i - P_i A_i^\top \left[ S^{-1} + A_a P_a A_a^\top \right]^{-1} A_j P_j. \qquad (17)$$

Typically $\Sigma_{ij}$ will be non-zero for any pair of entities $i$ and $j$ involved in the observation, even when the corresponding block in the prior is zero, i.e., when $i \neq j$. So over time, the covariance of the posterior would have more and more non-zero blocks.

So we need to approximate the covariance of the posterior, to retain the block diagonal structure we want to maintain. To accomplish this, we simply zero out any off-diagonal non-zero blocks. In practice, we simply never compute off-diagonal blocks. This finishes the update step for the DEKF that reflects the new observation in the parameter estimates. For models with static parameters, the DEKF only has an update step, resulting in Algorithm 1.
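A minimal sketch of this decoupled update for a scalar observation follows; the names are ours, and, as in the earlier univariate sketch, it is written in terms of per-entity gradients of the predicted mean. Off-diagonal blocks are simply never formed:

```python
import numpy as np

def dekf_update(means, covs, involved, grads, y, y_mean, sigma2):
    """means/covs: dicts entity -> mean vector / covariance block.
    involved: list of entity ids in the observation; grads: dict
    entity -> gradient of the predicted mean w.r.t. that entity's
    parameters, evaluated at the prior means."""
    # With a block-diagonal prior, g P g^T decomposes over entities.
    denom = sigma2 + sum(grads[i] @ covs[i] @ grads[i] for i in involved)
    err = y - y_mean
    for i in involved:
        Pg = covs[i] @ grads[i]
        means[i] = means[i] + Pg * err / denom
        covs[i] = covs[i] - np.outer(Pg, Pg) / denom  # diagonal block only
    return means, covs

# Usage sketch: two scalar entities with unit prior variances.
means = {0: np.array([0.0]), 1: np.array([0.0])}
covs = {0: np.array([[1.0]]), 1: np.array([[1.0]])}
grads = {0: np.array([1.0]), 1: np.array([1.0])}
dekf_update(means, covs, [0, 1], grads, y=1.0, y_mean=0.0, sigma2=1.0)
```

Entities not in `involved` are never touched, matching the factorized posterior result above.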

Data: Observation $y$, context $x$, time $t$, prior means $\mu_i$, prior covariances $P_i$, time per entity $t_i$
Result: $\mu_i$, $P_i$, $t_i$ updated in-place
Let $E$ be the entities involved in the observation.
/* --- Predict step --- */
1. For $i$ in $E$:
2.       If entity $i$ exists: advance $\mu_i$ and $P_i$ from time $t_i$ to time $t$ via Equations 23 and 24, and set $t_i = t$.
3.       Else: initialize $\mu_i$ and $P_i$ from the steady-state distribution via Equations 20 and 21, and set $t_i = t$.
/* --- Update step --- */
4. For $i$ in $E$: evaluate the signal and the gradient $A_i$ at the prior means.
5. Compute the predicted mean, the predicted variance, the error, and the matrix to invert, as in Algorithm 1.
6. For $i$ in $E$: update $\mu_i$ and $P_i$ via Equations 7 and 8, restricted to the diagonal blocks.
7. Return $\mu_i$, $P_i$, $t_i$.
Algorithm 2 Our DEKF variant.

As shown, the covariance update in Equation 8 will densify the covariance matrix, filling in non-zero covariance blocks across entities as observations accumulate. A good choice of entities will result in the full update producing few non-zero off-diagonal blocks, small both in dimension and in magnitude. These are the blocks that the DEKF zeroes out to maintain the block diagonal covariance.

We therefore suggest that reasonable entities to use, within the capabilities of available memory and computation, are groups of parameters whose components of the gradient of the natural parameter are commonly non-zero together. Entity identification can be empirical, based on tracking these co-occurrences, but for many models reasonable entities can be inferred directly from the model structure. However, when memory or computation are severely limited, each parameter can define an entity, resulting in a fully diagonal covariance matrix. On the other extreme, all parameters can define a single entity when the parameters are few relative to the available memory and computation.

For completeness, with $q$ parameters per entity such that there are $p = nq$ parameters in total, the memory storage and computation per observation for the extended Kalman filter are $O(p^2)$ and $O(p^2)$. For the DEKF, with $m$ entities involved per observation, these are $O(pq)$ and $O((mq)^2)$.

3.3 Parameter Dynamics

With parameter dynamics, parameter estimates need to be changed between observations to reflect these dynamics, resulting in the so-called predict step of Kalman filtering. In typical engineering and scientific applications of Kalman filtering, the parameters, considered as a system state, are assumed to undergo known linear dynamics plus Gaussian noise according to

$$w_{t+1} = F w_t + c + \epsilon_t, \qquad (18)$$

where $\epsilon_t$ is additive Gaussian noise, and where the dynamics matrix $F$ and the vector $c$ are known. In the EKF (and the original DEKF), the true dynamics are defined by non-linear functions that are approximated through a first-order Taylor expansion about the mean of the current posterior, resulting in essentially the same linear dynamics above.

For our purposes, the standard Kalman filter dynamics are too general, since the parameters that specify the dynamics, e.g., $F$, $c$, and the noise covariance, are typically unknown in machine learning applications. We consider parameter dynamics here only as a means to incorporate non-stationarity of the data. So we assume that the parameters $w_i$ of each entity evolve according to the following specialized dynamics:

$$w_{i,t+1} = \gamma \, w_{i,t} + (1 - \gamma) \, \beta_i + \epsilon_{i,t}. \qquad (19)$$

The initial conditions for the dynamics will be specified below. In Equation 19, $\epsilon_{i,t}$ is the driving Gaussian noise, with known covariance $Q_i$. The memory parameter $\gamma$ is a number between $0$ and $1$, rather than a matrix. In addition, we make the reference vectors $\beta_i$, about which the model parameters drift, random variables that we estimate jointly with the model parameters.

Equation 19 implies, for a fixed $\beta_i$, the steady-state distribution $\mathcal{N}\big(\beta_i, \, Q_i / (1 - \gamma^2)\big)$. Our motivation for adding random reference vectors, and the memory parameter $\gamma$, is two-fold. First, if $\gamma = 1$, the entity parameters undergo a random walk, and can accumulate a large covariance. E.g., in MF models such a random walk often produces user and item vectors that produce absurdly large signals. Second, these dynamics allow us to predict a reasonable mean, i.e., the reference vector, for entities that have not been observed in a long time.

We let the reference vector $\beta_i$ have a prior with mean $\mu_{\beta_i}$ and covariance $P_{\beta_i}$, both assumed known. We also extend the set of model parameters to include the reference vectors. For each entity $i$, we then track the mean of $\beta_i$, the covariance between $w_i$ and $\beta_i$, and the covariance matrix of $\beta_i$.

Algorithm 1 still holds for the extended set of model parameters that includes the reference vectors. However, this algorithm is now inefficient, since the gradient of the log-likelihood with respect to the reference vectors is zero, because the signal does not depend on them. Our full DEKF variant, in Algorithm 2, modifies the update step of Algorithm 1 to remove this inefficiency. I.e., in Algorithm 2, $A_i$ is just the gradient with respect to the entity's current parameters $w_i$.

Accounting for the dynamics in the posterior over the parameters is known as the predict step of the extended Kalman filter. Importantly, the predict step is only required when we utilize the posterior for a specific set of entities. In particular, the predict step must be applied immediately before the update step for the set of entities in an observation. As opposed to laboriously maintaining a posterior over all parameters at every time step, we maintain a lazy posterior over each entity by recording only the most recent posterior for each entity, and the last time that entity was updated. This is identical to an inference procedure that would update the posterior for all entities at every time step.

Consider a particular entity $i$. When we predict for this entity at time $t$, we first check whether we already have a past mean and covariance for the parameters of this entity. If not, we assume the current parameters are drawn from the steady-state distribution of the dynamics and set the means to

$$\mathbb{E}[w_i] = \mathbb{E}[\beta_i] = \mu_{\beta_i}, \qquad (20)$$

and the covariances to

$$\operatorname{Cov}\begin{bmatrix} w_i \\ \beta_i \end{bmatrix} = \begin{bmatrix} P_{\beta_i} + Q_i / (1 - \gamma^2) & P_{\beta_i} \\ P_{\beta_i} & P_{\beta_i} \end{bmatrix}. \qquad (21)$$

If entity $i$ has a posterior that was last updated at time $\tau < t$, we can write down the entire dynamics for the corresponding parameters between time $\tau$ and the current time $t$, with $k = t - \tau$, as

$$w_{i,t} = \gamma^k w_{i,\tau} + (1 - \gamma^k) \beta_i + \tilde{\epsilon}_i, \qquad \tilde{\epsilon}_i \sim \mathcal{N}\!\left(0, \, \frac{1 - \gamma^{2k}}{1 - \gamma^2} Q_i\right). \qquad (22)$$

This means we can directly update the entity's posterior at time $\tau$ to the posterior at time $t$. For the means, we have

$$\mathbb{E}[w_{i,t}] = \gamma^k \, \mathbb{E}[w_{i,\tau}] + (1 - \gamma^k) \, \mathbb{E}[\beta_i], \qquad (23)$$

while $\mathbb{E}[\beta_i]$ is unchanged. For the covariances, we have

$$\operatorname{Var}(w_{i,t}) = \gamma^{2k} \operatorname{Var}(w_{i,\tau}) + (1 - \gamma^k)^2 \operatorname{Var}(\beta_i) + 2 \gamma^k (1 - \gamma^k) \operatorname{Cov}(w_{i,\tau}, \beta_i) + \frac{1 - \gamma^{2k}}{1 - \gamma^2} Q_i, \qquad (24)$$

with $\operatorname{Cov}(w_{i,t}, \beta_i) = \gamma^k \operatorname{Cov}(w_{i,\tau}, \beta_i) + (1 - \gamma^k) \operatorname{Var}(\beta_i)$, and $\operatorname{Var}(\beta_i)$ unchanged.

Because the predict step for entities can predict across any number of discrete time-steps with the same computational cost, our particular choice of entity dynamics allows us to incorporate parameter drift efficiently.

We summarize the complete algorithm with the predict-update cycle in Algorithm 2.
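A sketch of the lazy predict step for a single entity follows, under the simplifying assumption (ours, not the paper's) that the reference vector is known and fixed; the paper additionally tracks the reference vector's uncertainty and its covariance with the entity parameters:

```python
import numpy as np

def lazy_predict(mu, Sigma, beta, gamma, Q, k):
    """Advance an entity's posterior k time steps in O(1) work.

    Dynamics assumed: theta_{t+1} = gamma * theta_t + (1 - gamma) * beta + w_t,
    with w_t ~ N(0, Q), 0 <= gamma < 1, and beta known and fixed.
    """
    g_k = gamma ** k
    # Mean relaxes geometrically toward the reference vector.
    mu_k = g_k * mu + (1.0 - g_k) * beta
    # Accumulated noise sums a geometric series of Q terms.
    Sigma_k = gamma ** (2 * k) * Sigma \
        + (1.0 - gamma ** (2 * k)) / (1.0 - gamma ** 2) * Q
    return mu_k, Sigma_k

# Usage sketch: after many unobserved steps the posterior relaxes to the
# steady state N(beta, Q / (1 - gamma^2)).
mu_k, Sigma_k = lazy_predict(np.array([2.0]), np.array([[5.0]]),
                             beta=np.array([0.0]), gamma=0.5,
                             Q=np.array([[0.75]]), k=100)
```

The cost is independent of `k`, which is what makes recording only the last update time per entity sufficient.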

3.4 Model Examples

To specialize the algorithm to a model, we find $\partial \eta / \partial w_i = (\partial \eta / \partial s)(\partial s / \partial w_i)$ and substitute this into the procedures above. The first term, $\partial \eta / \partial s$, follows from the choice of link function used. The second term, $\partial s / \partial w_i$, is the gradient of the signal with respect to an entity, and depends on the model class. It is easy to write it down explicitly for the models we consider.

  1. Multivariate regression. The simplest model class we consider is regression, where $s = X w$, so that $\partial s / \partial w = X$, and where $\partial s / \partial w_a$ is the subset of $X$ corresponding to the entities involved in the observation. (A similar algorithm for regression was developed in some detail in [5], but without using the concept of entities and reference vectors, and working with the Hessian of the log-likelihood rather than with the Fisher information matrix. As discussed later, the latter only changes the algorithm for regression models when non-canonical links are used.)

  2. Univariate regression. Here $s = x^\top w$, and $\partial s / \partial w = x^\top$. Similarly, $\partial s / \partial w_i = x_i^\top$, where $x_i$ is the subset of $x$ that contains the predictors that correspond to entity $i$.

  3. Matrix factorization. Since $s = u^\top v$, we have that $\partial s / \partial u = v^\top$ and $\partial s / \partial v = u^\top$. For all other entities, $\partial s / \partial w_i = 0$.

  4. Tensor factorization. Assume $m$ entities, with vectors $v_1, \dots, v_m$, are involved in the observation. We then have that $s = \sum_{f=1}^{k} \prod_{j=1}^{m} v_j[f]$. The signal gradient is then $\partial s / \partial v_j[f] = \prod_{l \neq j} v_l[f]$ for the involved entities, and $\partial s / \partial w_i = 0$ for other entities $i$.

  5. Factorization machines. The signal is given by Equation 5. Because the signal is a polynomial in the parameters, its gradients are easy to write down. Assume that only the first $m$ entities are involved in the observation, i.e., have non-zero $x_i$. For entity $i$, recall that $w_i$ consists of stacking the vectors $v_i^{(l)}$, so $\partial s / \partial w_i$ consists of stacking the corresponding gradients $\partial s / \partial v_i^{(l)}$.

    The $f$-th entry of $\partial s / \partial v_i^{(l)}$ for any involved entity $i$ and order $l \geq 2$ can be calculated via

    $$\frac{\partial s}{\partial v^{(l)}_{i,f}} = x_i \sum_{\substack{i_2 < \dots < i_l \\ i_j \neq i}} \prod_{j=2}^{l} x_{i_j} \, v^{(l)}_{i_j, f}. \qquad (25)$$

    For second-order FM models, we only need the derivatives above for $l = 1$ and $l = 2$. Note that for $l = 1$ the parameters are the scalars $b_i$, and the derivative is simply $\partial s / \partial b_i = x_i$. Similarly, for $l = 2$, Equation 25 simplifies to

    $$\frac{\partial s}{\partial v_{i,f}} = x_i \sum_{j \neq i} x_j \, v_{j,f}. \qquad (26)$$
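As a sanity check on the second-order case, the following sketch (our own construction) implements the pairwise part of the order-2 FM signal and its gradient with respect to one entity's vector, and compares the gradient against a finite difference:

```python
import numpy as np

def fm2_signal(x, V):
    """Pairwise part of the order-2 FM signal: sum_{i<j} x_i x_j v_i^T v_j."""
    n = len(x)
    return sum(x[i] * x[j] * (V[i] @ V[j])
               for i in range(n) for j in range(i + 1, n))

def fm2_grad_vi(x, V, i):
    """Gradient of the pairwise signal w.r.t. v_i: x_i * sum_{j != i} x_j v_j."""
    return x[i] * sum(x[j] * V[j] for j in range(len(x)) if j != i)

rng = np.random.default_rng(1)
x = np.array([1.0, 1.0, 0.0])     # third entity not involved
V = rng.normal(size=(3, 2))

g = fm2_grad_vi(x, V, 0)

# Finite-difference check on the first coordinate of v_0.
eps = 1e-6
Vp = V.copy()
Vp[0, 0] += eps
fd = (fm2_signal(x, Vp) - fm2_signal(x, V)) / eps
```

Because the signal is bilinear in each entity's vector, the finite difference matches the analytic gradient to numerical precision.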

4 The EKF And Related Algorithms

There are several algorithms that are related to the EKF and DEKF. In this section, we start with an alternative derivation of the EKF that is useful to more directly compare it with other algorithms, including an iterated version of the EKF that can be helpful when the EKF approximations are not valid.

4.1 Deriving the EKF

Our goal here is to derive the EKF update step, in Equations 7 and 8, for our general model. A standard derivation of the EKF goes as follows. First, $p(y \mid w)$ is approximated as a Gaussian according to $p(y \mid w) \approx \mathcal{N}\big(\mu(w), S(\mu)\big)$. Notice that the variance is evaluated at the mean of the prior, while the mean is allowed to depend on $w$. To make the log-likelihood a quadratic function of $w$, $\mu(w)$ is then approximated through a first-order Taylor expansion around the prior mean $\mu$. This then yields the same update described in Equations 33 and 34 after some algebra. Alternatively, one can view the EKF as the linear minimum-squared error estimator of $w$ given $y$, after linearizing $\eta$ and $\mu$ about the prior mean.

We show next a different derivation of the EKF that brings connections to other methods and statistical concepts more directly. Consider Equation 11. In general, $\log p(y \mid w)$ is not a quadratic function of $w$ like $\log p(w)$, so the true posterior is not Gaussian. So we proceed as follows: we approximate $\log p(y \mid w)$ with a quadratic function of $w$ through a second-order Taylor expansion about the prior mean $\mu$. We then take the expectation of the corresponding Hessian over the distribution of $y$ given $w$ to guarantee that the covariance matrix remains positive definite. Lastly, we do some algebra to obtain the desired equations.

To start, we note that the gradient of the log-likelihood is

$$\nabla_w \log p(y \mid w) = A(w)^\top \big( y - \mu(w) \big). \qquad (27)$$

The (conditional) Fisher information matrix of $w$ plays a prominent role in our derivation. It is given by

$$I(w) = \mathbb{E}_w\!\left[ \nabla_w \log p(y \mid w) \, \nabla_w \log p(y \mid w)^\top \right] = -\mathbb{E}_w\!\left[ \nabla_w^2 \log p(y \mid w) \right] = A(w)^\top S(w) A(w), \qquad (28)$$

where the first two equalities are essentially definitions, and the last equality is specific to our model assumptions. We use the notation $\mathbb{E}_w$ to emphasize that this expectation is over samples of $y$ from the statistical model with parameters $w$. For clarity, the natural parameter $\eta$, and signal $s$, can be functions of context $x$ that accompanied the observation $y$. (The true Fisher information matrix is an average over the unknown distribution of contexts and over the model distribution for $y$ given $w$ and $x$. The above Fisher information is the conditional Fisher information considered for a fixed context.)

The Hessian of the log-likelihood is

$$\nabla_w^2 \log p(y \mid w) = -I(w) + R(w), \qquad (29)$$

an explicit function of the Fisher information matrix. The first term in the last equation is a negative definite matrix. The second term, $R(w)$, will be dropped shortly. It is a matrix that involves a tensor of second derivatives, with $(i,j)$-th entry given by:

$$R_{ij}(w) = \big( y - \mu(w) \big)^\top \frac{\partial^2 \left( \Phi^{-1} \eta \right)}{\partial w_i \, \partial w_j}. \qquad (30)$$

The matrix in Equation 30 is not necessarily negative definite, and we will see below that this could result in invalid covariance matrices that are not positive definite. But there are several ways to set it to zero. The more general one, and the one we use, is to replace the Hessian in Equation 29 by its average over $y$ given $w$, i.e., by $-I(w)$. This is consistent with Equation 29, which uses $y$ only in the second term on the right, through the error $y - \mu(w)$, and the error averaged over $y$ given $w$ is zero. In this sense, the negative Hessian in Equation 29 can be seen as a sample of the Fisher information matrix for a value of $y$.

Equation 30 also evaluates to zero for any regression model that uses the canonical link. In general, $\partial^2 \eta / \partial w_i \partial w_j$ contains one term involving the second derivative of the link function with respect to the signal, and one term involving the second derivative of the signal with respect to the parameters. For any model where the canonical link is used, the signal is the natural parameter, so the first term vanishes. For regression models, the second term is also zero, because the signal is linear in $w$. So for regression models with the canonical link, the Hessian is identical to the negative Fisher information matrix, and we could have used the Hessian directly to obtain the same parameter updates.

Combining these results we obtain our approximation of the log-likelihood about a reference value $v$ of $w$:

$$\log p(y \mid w) \approx \log p(y \mid v) + (w - v)^\top A(v)^\top \big( y - \mu(v) \big) - \tfrac{1}{2} (w - v)^\top I(v) (w - v). \qquad (31)$$

Plugging this approximation, evaluated at $v = \mu$, into Equation 11, as well as writing the Gaussian prior of $w$ explicitly, while dropping terms independent of $w$, yields

$$\log p(w \mid y) \approx \text{const} - \tfrac{1}{2} (w - \hat{\mu})^\top \Sigma^{-1} (w - \hat{\mu}), \qquad (32)$$

with

$$\Sigma^{-1} = P^{-1} + I(\mu), \qquad (33)$$
$$\hat{\mu} = \mu + \Sigma A^\top \big( y - \mu(\mu) \big). \qquad (34)$$

The last equality in Equation 32 is obtained by completing squares. The result shows that the approximate posterior distribution is also Gaussian with mean $\hat{\mu}$ and covariance $\Sigma$.

The variance update in Equation 8 follows from applying the Woodbury identity, e.g., see [21], to Equation 33, and some re-arrangement. Plugging Equation 8 into Equation 34 yields the mean update in Equation 7, also after some re-arrangement.
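The Woodbury step can be sketched as follows, writing, in our notation, $A$ for the auxiliary matrix evaluated at the prior mean and $S$ for the observation covariance, so that the Fisher information is $A^\top S A$:

```latex
\Sigma \;=\; \left(P^{-1} + A^\top S A\right)^{-1}
       \;=\; P \;-\; P A^\top \left(S^{-1} + A P A^\top\right)^{-1} A P ,
```

so the matrix that must be inverted is only $d$-by-$d$ rather than $p$-by-$p$.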

4.2 The Iterated EKF

Different second-order approximations of the log-likelihood will result in update equations different from the EKF. For example, consider approximating $\log p(y \mid w)$ about an arbitrary value $v$, rather than about the prior mean $\mu$:

$$\log p(y \mid w) \approx \log p(y \mid v) + (w - v)^\top \nabla_w \log p(y \mid v) + \tfrac{1}{2} (w - v)^\top \nabla_w^2 \log p(y \mid v) \, (w - v). \qquad (35)$$

Working through the rest of the EKF derivation in the same way as before results in the following update equations:

$$\Sigma^{-1} = P^{-1} - \nabla_w^2 \log p(y \mid v), \qquad (36)$$
$$\hat{\mu} = \mu + \Sigma \left[ \nabla_w \log p(y \mid v) - \nabla_w^2 \log p(y \mid v) \, (v - \mu) \right]. \qquad (37)$$

Note that the column vector that multiplies $\Sigma$ on the right to determine the mean update now has two terms, and the second term goes to zero when $v = \mu$. Also note that Equation 36 may lead to a "covariance" that is not positive-definite, or even worse, singular. Using the Fisher information matrix, like the EKF does, instead of the Hessian, is one alternative, and results in the update

$$\Sigma^{-1} = P^{-1} + I(v), \qquad (38)$$
$$\hat{\mu} = \mu + \Sigma \left[ A(v)^\top \big( y - \mu(v) \big) + I(v) (v - \mu) \right]. \qquad (39)$$

Now consider a reference point $v$ that is self-consistent, i.e., that results in $\hat{\mu} = v$. Under these circumstances, we get from Equation 39 that

$$A(v)^\top \big( y - \mu(v) \big) - P^{-1} (v - \mu) = 0, \qquad (40)$$

i.e., that $\nabla_w \big[ \log p(y \mid w) + \log p(w) \big] = 0$ at $w = v$. Therefore, a self-consistent $v$ is a stationary point of the total log-likelihood. In particular, the maximum-a-posteriori (MAP) estimate of $w$ satisfies this equation.

The iterated extended Kalman filter (IEKF) computes a MAP estimate by iterating

$$v_{i+1} = v_i + \alpha_i \big( \hat{\mu}(v_i) - v_i \big), \qquad (41)$$

where $\hat{\mu}(v_i)$ denotes the mean in Equation 39 computed with reference point $v_i$, using a line-search for the step size $\alpha_i$ to ensure that the likelihood is increasing on each iteration [28]. Upon convergence, the updated mean is the converged $v$, and the updated covariance comes from Equation 38 evaluated at the converged $v$. (Technically the second-to-last $v$ is generally used for the covariance, since the relevant terms have already been computed.)

After applying the Woodbury identity and some re-arrangement, can be written for our exponential family models as