Recursive Variational Bayesian Dual Estimation for Nonlinear Dynamics and Non-Gaussian Observations
State space models provide an interpretable framework for complex time series by combining an intuitive dynamical system model with a probabilistic observation model. We develop a flexible online learning framework for latent nonlinear state dynamics and filtered latent states. Our method uses the stochastic gradient variational Bayes method to jointly optimize the parameters of the nonlinear dynamics model, the observation model, and the recognition model. Unlike previous approaches, our framework can incorporate non-trivial observation noise models and perform inference in real time. We test our method on point process observations driven by continuous attractor dynamics, demonstrating its ability to recover the phase portrait and filtered trajectory and to produce long-term predictions for neuroscience applications.
Given an observed time series $y_t$ and input $u_t$, can we build a state space model that captures the underlying nonlinear dynamics responsible for its generation [Haykin1998]? More specifically, we would like to identify a continuous nonlinear process that captures the temporal structure, and an instantaneous noisy observation process:

$$\begin{aligned} \dot{x} &= F(x(t), u(t)), \\ y_t &\sim P(G(x_t)), \end{aligned} \tag{1}$$

where $F$ and $G$ are continuous functions, and $P$ denotes a probability distribution. Such a continuous state space model is natural in many applications where changes are slow and follow physical laws and constraints (e.g., object tracking). Our target application is neural data analysis, where inferring state space models underlying neural activity can provide insights into neural dynamics [Breakspear2017; Kao2015b; Paninski2009; Yu2009] and neural computation [Mante2013; Sussillo2013; Zhao2016d], and can support the development of neural prosthetics and treatments based on feedback control [ODoherty2011; Little2012].
Because time series are uniformly sampled, it is often more convenient to formulate the state dynamics in discrete time with noise:

$$x_{t+1} = f(x_t, u_t) + \epsilon_t, \tag{2}$$

where $\epsilon_t$ captures unobserved perturbations of the state $x_t$. The noise in the dynamics and in the observations is essential for finding a concise description (a low-dimensional phase portrait) of the observations [Roweis2001].
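One way to connect Eq. (1) and Eq. (2) is sketched below, assuming a forward Euler (Euler-Maruyama) discretization with step size `dt`, so that $f(x, u) = x + \Delta t\, F(x, u)$ and $\epsilon_t$ is Gaussian with standard deviation `noise_std * sqrt(dt)`. The scheme, the function names, and the example vector field are illustrative assumptions, not taken from the experiments below.

```python
import numpy as np


def euler_step(F, x, u, dt, noise_std, rng):
    """One step of x_{t+1} = f(x_t, u_t) + eps_t with f(x, u) = x + dt * F(x, u).

    eps_t ~ N(0, noise_std**2 * dt * I), i.e. Euler-Maruyama noise scaling.
    """
    eps = noise_std * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x + dt * F(x, u) + eps


def simulate(F, x0, inputs, dt, noise_std, seed=0):
    """Roll out len(inputs) steps from x0; returns an array of shape (len(inputs) + 1, dim)."""
    rng = np.random.default_rng(seed)
    xs = [np.asarray(x0, dtype=float)]
    for u in inputs:
        xs.append(euler_step(F, xs[-1], u, dt, noise_std, rng))
    return np.stack(xs)


# Hypothetical vector field: linear decay toward the origin plus a scalar input.
def decay_field(x, u):
    return -x + u


traj = simulate(decay_field, x0=[1.0, -1.0], inputs=np.zeros(500), dt=0.01, noise_std=0.0)
```

Without noise, each step multiplies the state by `1 - dt`, so the trajectory decays geometrically; with `noise_std > 0` the same code produces the noisy discrete-time process of Eq. (2).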
If the nonlinear state space model is fully specified, Bayesian methods can estimate the latent states (either the filtering distribution $p(x_t \mid y_{1:t})$ or the smoothing distribution $p(x_t \mid y_{1:T})$), predict future states $p(x_{t+s} \mid y_{1:t})$, and predict observations $p(y_{t+s} \mid y_{1:t})$ for $s > 0$ [Ho1964; Sarkka2013]. However, in many applications, the challenge lies in learning the parameters of the state space model (a.k.a. the system identification problem). Learning both the latent state trajectory and the latent (nonlinear) state space model is known as the dual estimation problem [Haykin2001]. Expectation maximization (EM) based methods have been widely used in practice [Ghahramani1999; Valpola2006], and more recently variational autoencoder methods have been proposed [Archer2015a; Krishnan2015; Krishnan2017a; Johnson2016; Karl2017; Watter2015], all of which are designed for offline batch analysis.
In this paper, we are interested in real-time signal processing and state-space control settings, where we need online algorithms that can recursively solve the dual estimation problem on streaming observations. A popular solution to this problem exploits the fact that online state estimators for nonlinear state space models, such as the extended Kalman filter (EKF) or the unscented Kalman filter (UKF), can also be used for nonlinear regression formulated as a state space model:

$$\begin{aligned} \theta_{t+1} &= \theta_t + \xi_t, \\ y_t &= \phi(x_t; \theta_t) + \zeta_t, \end{aligned}$$

where $\phi$ is a function approximator parameterized by $\theta$ that maps $x$ to $y$. Therefore, by augmenting the state space with the parameters, one can build an online dual estimator using nonlinear Kalman filters [Wan2000; Wan2001]. However, these methods rely on coarse approximations of Bayesian filtering, involve many hyperparameters, do not take advantage of modern stochastic gradient optimizers, and are not easily applicable to arbitrary observation likelihoods. Closely related online versions of EM-type algorithms [Roweis2001] share similar concerns. In this paper, we derive a black-box inference framework, applicable to a wide range of nonlinear state space models, that is truly online, that is, the computational demand of the algorithm is constant per time step.
2 Variational Principle for Online Dual Estimation
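The crux of recursive Bayesian filtering is updating the posterior over the latent state one time step at a time:

$$p(x_t \mid y_{1:t}) = \frac{p(y_t \mid x_t)\, p(x_t \mid y_{1:t-1})}{p(y_t \mid y_{1:t-1})}, \tag{5a}$$

$$p(x_t \mid y_{1:t-1}) = \int p(x_t \mid x_{t-1})\, p(x_{t-1} \mid y_{1:t-1})\, dx_{t-1}, \tag{5b}$$

where the input $u_t$ is omitted for brevity. Unfortunately, the exact calculations in (5) are not tractable in general, especially for a nonlinear dynamics model and a non-conjugate likelihood. We derive a recursive variational Bayesian filter by deriving a lower bound for the marginal likelihood. Let $q(x_t)$ denote an arbitrary probability measure, which will eventually approximate $p(x_t \mid y_{1:t})$. From (5a),

$$
\begin{aligned}
\log p(y_t \mid y_{1:t-1}) &= \log \int p(y_t \mid x_t)\, p(x_t \mid y_{1:t-1})\, dx_t \\
&\geq \mathbb{E}_{q(x_t)}\left[\log p(y_t \mid x_t) + \log p(x_t \mid y_{1:t-1}) - \log q(x_t)\right] \\
&= \log p(y_t \mid y_{1:t-1}) - D_{\mathrm{KL}}\left(q(x_t) \,\|\, p(x_t \mid y_{1:t})\right) \\
&\geq \mathbb{E}_{q(x_t)}\left[\log p(y_t \mid x_t)\right] + \mathbb{E}_{q(x_t)}\mathbb{E}_{p(x_{t-1} \mid y_{1:t-1})}\left[\log p(x_t \mid x_{t-1})\right] + \mathcal{H}(q(x_t)) \\
&\approx \mathbb{E}_{q(x_t)}\left[\log p(y_t \mid x_t)\right] + \mathbb{E}_{q(x_t)\, q(x_{t-1})}\left[\log p(x_t \mid x_{t-1})\right] + \mathcal{H}(q(x_t)) =: \mathcal{L}(x_t),
\end{aligned}
\tag{6}
$$

where $D_{\mathrm{KL}}$ denotes the Kullback-Leibler divergence, $\mathcal{H}$ denotes the entropy, and in the last step we plugged in the variational posterior $q(x_{t-1})$ from the previous time step in place of $p(x_{t-1} \mid y_{1:t-1})$ to form a recursive estimate.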
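The bound can be checked numerically in a fully Gaussian toy case. The sketch below assumes a hypothetical scalar linear-Gaussian model, $x_t = a x_{t-1} + \mathcal{N}(0, q)$ and $y_t = x_t + \mathcal{N}(0, r)$, with $q(x_{t-1})$ set to the exact filtering posterior $\mathcal{N}(m, P)$, so that the final approximation step is exact and $\mathcal{L}$ must lie below $\log p(y_t \mid y_{1:t-1})$ for every choice of $q(x_t) = \mathcal{N}(\mu, s^2)$. All numbers are arbitrary illustrative values.

```python
import numpy as np


def exact_log_predictive(y, m, P, a, q, r):
    """log p(y_t | y_{1:t-1}) when p(x_{t-1} | y_{1:t-1}) = N(m, P)."""
    var = a**2 * P + q + r
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (y - a * m) ** 2 / var


def recursive_elbo(y, mu, s2, m, P, a, q, r):
    """Closed-form L = E_q[log p(y_t|x_t)] + E_{q(x_t)q(x_{t-1})}[log p(x_t|x_{t-1})] + H(q(x_t)),
    with q(x_t) = N(mu, s2) and q(x_{t-1}) = N(m, P). Vectorized over mu and s2."""
    recon = -0.5 * np.log(2 * np.pi * r) - ((y - mu) ** 2 + s2) / (2 * r)
    # E[(x_t - a x_{t-1})^2] for independent Gaussian x_t and x_{t-1}
    sq = (mu - a * m) ** 2 + s2 + a**2 * P
    dyn = -0.5 * np.log(2 * np.pi * q) - sq / (2 * q)
    entropy = 0.5 * np.log(2 * np.pi * np.e * s2)
    return recon + dyn + entropy


y, m, P, a, q, r = 0.7, 0.2, 0.5, 0.9, 0.1, 0.3
exact = exact_log_predictive(y, m, P, a, q, r)
# Maximizer of the closed-form bound (derivatives w.r.t. mu and s2 set to zero).
s2_opt = 1.0 / (1.0 / r + 1.0 / q)
mu_opt = s2_opt * (y / r + a * m / q)
best = recursive_elbo(y, mu_opt, s2_opt, m, P, a, q, r)
```

In this toy case the maximizing variance $(1/r + 1/q)^{-1}$ is smaller than the exact filtering variance $(1/r + 1/(a^2 P + q))^{-1}$, which gives a concrete view of how the second variational gap discussed below can bias $q(x_t)$.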
Online variational inference is achieved by maximizing this approximate lower bound $\mathcal{L}$ with respect to the parameters of the generative model (dynamics and observation) and the variational posterior $q(x_t)$, given $q(x_{t-1})$ estimated at the previous time step. Maximizing $\mathcal{L}$ is equivalent to minimizing two variational gaps: (1) the variational filtering posterior $q(x_t)$ has to be close to the true filtering posterior $p(x_t \mid y_{1:t})$, and (2) the filtering posterior from the previous step, $p(x_{t-1} \mid y_{1:t-1})$, needs to be close to $p(x_{t-1} \mid x_t, y_{1:t})$. Note that this second gap is invariant to $q(x_t)$ if $p(x_{t-1} \mid x_t, y_{1:t}) = p(x_{t-1} \mid y_{1:t-1})$, that is, if the one-step backward smoothing distribution is identical to the filtering distribution.
From another angle, there are intuitively three components in $\mathcal{L}$ that are jointly maximized: (1) the reconstruction log-likelihood $\mathbb{E}_{q(x_t)}[\log p(y_t \mid x_t)]$, which is maximized if $q(x_t)$ concentrates around the maximum likelihood estimate given only $y_t$; (2) the dynamics log-likelihood $\mathbb{E}_{q(x_t)\,q(x_{t-1})}[\log p(x_t \mid x_{t-1})]$, which is maximized if $q(x_t)$ concentrates around the maximum of $p(x_t \mid x_{t-1})$; and (3) the entropy term $\mathcal{H}(q(x_t))$, which expands the posterior and keeps it from collapsing to a point mass.
We choose the variational posterior over the state to be a multivariate normal with diagonal covariance, $q(x_t) = \mathcal{N}(\mu_t, \operatorname{diag}(\sigma_t^2))$. To amortize the computational cost of the optimization needed to obtain the best $q(x_t)$ at each time step, we employ the variational autoencoder architecture [Hinton1995] and parametrize $q(x_t)$ with a recognition model. Intuitively, the recognition model embodies the optimization process of finding $q(x_t)$, that is, it performs an approximate Bayesian filtering computation of (5) in constant time, according to the objective function $\mathcal{L}$. We use a recursive recognition network that maps $q(x_{t-1})$ and $(y_t, u_t)$ to $q(x_t)$. In particular, the recognition model is a deterministic recurrent neural network (RNN) with no extra hidden state:

$$(\mu_t, \sigma_t^2) = h(\mu_{t-1}, \sigma_{t-1}^2, y_t, u_t). \tag{7}$$

We use a simple multi-layer perceptron as $h$. Note that the Markovian architecture of the recognition model reflects the Markovian structure of the filtering computation (cf. smoothing networks, which often use bidirectional RNNs [Sussillo2016] or graphical models [Archer2015a; Johnson2016]).
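A minimal NumPy sketch of one step of such a recognition network follows, under stated assumptions: $h$ is a one-hidden-layer tanh MLP whose input concatenates $\mu_{t-1}$, $\sigma^2_{t-1}$, $y_t$, and $u_t$, and a softplus output keeps the variance positive. The parameter names, the initialization scale, and the softplus choice are hypothetical; in practice the weights would be trained jointly with the generative model.

```python
import numpy as np


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return np.logaddexp(0.0, z)


def init_recognition(dim_x, dim_y, dim_u, n_hidden, seed=0):
    """Random initial weights for a one-hidden-layer MLP h (hypothetical layout)."""
    rng = np.random.default_rng(seed)
    d_in = 2 * dim_x + dim_y + dim_u  # input: [mu_{t-1}, var_{t-1}, y_t, u_t]
    return {
        "W1": rng.normal(scale=d_in ** -0.5, size=(n_hidden, d_in)),
        "b1": np.zeros(n_hidden),
        "W_mu": rng.normal(scale=n_hidden ** -0.5, size=(dim_x, n_hidden)),
        "b_mu": np.zeros(dim_x),
        "W_var": rng.normal(scale=n_hidden ** -0.5, size=(dim_x, n_hidden)),
        "b_var": np.zeros(dim_x),
    }


def recognition_step(params, mu_prev, var_prev, y_t, u_t):
    """Map q(x_{t-1}) = N(mu_prev, diag(var_prev)) and (y_t, u_t) to q(x_t).

    Nothing besides (mu, var) is carried from one step to the next.
    """
    z = np.concatenate([mu_prev, var_prev, y_t, u_t])
    hidden = np.tanh(params["W1"] @ z + params["b1"])
    mu = params["W_mu"] @ hidden + params["b_mu"]
    var = softplus(params["W_var"] @ hidden + params["b_var"]) + 1e-6  # strictly positive
    return mu, var
```

Because each call consumes only $(\mu_{t-1}, \sigma^2_{t-1})$ and the current $(y_t, u_t)$, its cost is the same at every time step.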
We use the reparameterization trick and stochastic gradient variational Bayes [Rezende2014; Kingma2014]: we rewrite the expectations over $q$ as expectations over a standard normal random variable, and we use a single sample for each time step. Hence, in practice, we optimize the following objective function:

$$\hat{\mathcal{L}}(x_t) = \log p(y_t \mid \tilde{x}_t) + \log p(\tilde{x}_t \mid \tilde{x}_{t-1}) + \mathcal{H}(q(x_t)), \tag{8}$$

where $\tilde{x}_t$ and $\tilde{x}_{t-1}$ represent samples drawn from $q(x_t)$ and $q(x_{t-1})$, respectively. Thus, unlike dual-form nonlinear Kalman filtering methods, our method can handle arbitrary observation and dynamics models.
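To illustrate Eq. (8), the sketch below draws one reparameterized sample from each of $q(x_t)$ and $q(x_{t-1})$ and evaluates the objective for Poisson observations with an exponential link, a Gaussian transition with diagonal covariance `Q`, and the analytic entropy of the diagonal Gaussian. The function names, the `trans_mean` callback, and the diagonal `Q` are assumptions made for the example; in an autodiff framework the same expressions would be differentiated with respect to all parameters.

```python
import numpy as np
from math import lgamma


def poisson_loglik(y, x, C, d):
    """log p(y | x) for y_i ~ Poisson(exp((C x + d)_i))."""
    eta = C @ x + d
    return float(np.sum(y * eta - np.exp(eta))) - sum(lgamma(k + 1) for k in y)


def gaussian_loglik_diag(x, mean, var):
    """log N(x; mean, diag(var))."""
    return float(-0.5 * np.sum(np.log(2 * np.pi * var) + (x - mean) ** 2 / var))


def single_sample_objective(rng, y_t, mu_t, var_t, mu_prev, var_prev, C, d, trans_mean, Q):
    """One-sample estimate: log p(y_t|x_t) + log p(x_t|x_{t-1}) + H(q(x_t))."""
    # Reparameterization: x = mu + sigma * eps with eps ~ N(0, I), so the sample is a
    # deterministic, differentiable function of the recognition outputs (mu, var).
    x_t = mu_t + np.sqrt(var_t) * rng.standard_normal(mu_t.shape)
    x_prev = mu_prev + np.sqrt(var_prev) * rng.standard_normal(mu_prev.shape)
    recon = poisson_loglik(y_t, x_t, C, d)
    dyn = gaussian_loglik_diag(x_t, trans_mean(x_prev), Q)
    entropy = 0.5 * float(np.sum(np.log(2 * np.pi * np.e * var_t)))  # analytic, not sampled
    return recon + dyn + entropy
```

Averaging many single-sample estimates recovers the expectation form of $\mathcal{L}$; the online algorithm instead uses one sample per time step and relies on the stochastic optimizer to average out the noise.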
Let $\Theta$ denote the set of all parameters of both the generative and recognition models. The objective in Eq. (8) is differentiable with respect to $\Theta$. We optimize it by gradient ascent using Adam [Kingma2014], implemented in TensorFlow [tensorflow2015-whitepaper]. Algorithm 1 gives an overview of the recursive estimation algorithm. We outline the algorithm for a single vector time series, but multiple sequences sharing a common state space model can be filtered simultaneously, in which case the gradients are averaged across the instantiations.
Note that this algorithm has constant time complexity per time step.
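The loop described above can be sketched as follows; this is a hedged skeleton, not the authors' Algorithm 1. At each time step the recognition model proposes $q(x_t)$, one Adam step is taken on all parameters, and $q(x_t)$ is carried forward. The callback `objective_grad` is a hypothetical placeholder (in practice its gradients would come from automatic differentiation), and the Adam rule below is the standard one, written for ascent.

```python
import numpy as np


class AdamAscent:
    """Adam update rule applied as gradient ascent on a dict of NumPy arrays."""

    def __init__(self, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.m, self.v, self.k = {}, {}, 0

    def step(self, params, grads):
        self.k += 1
        new_params = dict(params)
        for name, g in grads.items():
            self.m[name] = self.b1 * self.m.get(name, 0.0) + (1 - self.b1) * g
            self.v[name] = self.b2 * self.v.get(name, 0.0) + (1 - self.b2) * g**2
            m_hat = self.m[name] / (1 - self.b1**self.k)  # bias-corrected first moment
            v_hat = self.v[name] / (1 - self.b2**self.k)  # bias-corrected second moment
            new_params[name] = params[name] + self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
        return new_params


def online_dual_estimation(stream, params, q_init, objective_grad, opt):
    """One pass over a stream with a single gradient step per time step.

    objective_grad(params, q_prev, y_t, u_t) -> (grads, q_t) is a hypothetical callback that
    returns the gradient of the single-sample objective w.r.t. every parameter (generative
    and recognition) together with the recognition model's new posterior q(x_t).
    """
    q_prev = q_init
    for y_t, u_t in stream:
        grads, q_t = objective_grad(params, q_prev, y_t, u_t)
        params = opt.step(params, grads)
        q_prev = q_t  # the filtered posterior becomes the input for the next step
    return params, q_prev
```

Each iteration does a fixed amount of work, matching the constant per-step cost noted above; filtering several sequences at once would only change `objective_grad`, which would then return gradients averaged across sequences.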
3 Application to Latent Neural Dynamics
Our primary application is real-time neural interfaces, where a population of neurons is recorded while a low-dimensional stimulation is delivered [Newman2015; ElHady2016]. Latent state space modeling of such neural time series has been successful in describing population dynamics [Zhao2016a; Zhao2017b]. Moreover, models of neural computation are often described as dynamical systems, for example, attractor dynamics where convergence to one of the attractors represents the result of a computation [Zhao2016d]. Here we propose a parameterization and visualization tools that are suitable for studying neural dynamics and building neural interfaces [Zhao2016d].
3.1 Parameterization of the generative model
We consider high-temporal-resolution spike train data, where each time bin contains at most one action potential. Hence our observed time series is a stream of sparse binary vectors. All analyses in this study use 1 ms time bins.
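For concreteness, here is a small sketch of turning spike times into the binary observation stream, assuming spikes are given as hypothetical `(neuron_index, time_ms)` pairs; any bin that happens to receive more than one spike is clipped to 1, consistent with the at-most-one-spike assumption.

```python
import numpy as np


def bin_spikes(spikes, n_neurons, duration_ms, bin_ms=1.0):
    """Convert (neuron_index, spike_time_ms) pairs into a (T, n_neurons) binary array.

    Bins that receive more than one spike are clipped to 1, matching the
    at-most-one-spike-per-bin assumption.
    """
    n_bins = int(np.ceil(duration_ms / bin_ms))
    Y = np.zeros((n_bins, n_neurons), dtype=np.int8)
    for neuron, t_ms in spikes:
        k = int(t_ms // bin_ms)
        if 0 <= k < n_bins:
            Y[k, neuron] = 1
    return Y


Y = bin_spikes([(0, 0.2), (0, 0.7), (1, 3.5), (2, 9.99)], n_neurons=3, duration_ms=10.0)
```

Each row of `Y` is then one observation $y_t$ fed to the online algorithm.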
Our generative model assumes that the spike train observation $y_t$ is sampled from a probability distribution $P$ determined by the state $x_t$ through a linear-nonlinear map, possibly together with extra parameters, at each time $t$:

$$y_t \sim P\big(g(C x_t + d)\big), \tag{9}$$

where $g$ is a point nonlinearity. We use the canonical link for point process observations in this study, i.e., $g = \exp$. Note that this model is not identifiable, since $C x_t = (C A)(A^{-1} x_t)$, where $A$ is an arbitrary invertible matrix. Also, the mean of $x_t$ can be traded off with the bias term $d$. It is straightforward to include additional additive exogenous variables, or a history filter for the refractory period and stimulation artifacts.
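A short sketch of the log-likelihood with the exponential (canonical) link, assuming Poisson counts per bin with rate $\exp(C x_t + d)$; because $y_t$ is binary here, the $\log y!$ term vanishes. The example then checks both non-identifiabilities numerically. Names and the random test values are illustrative only.

```python
import numpy as np


def spike_loglik(y, x, C, d):
    """Poisson log-likelihood sum_i [y_i * eta_i - exp(eta_i)] with eta = C x + d.

    The -log(y_i!) term is zero for binary spike vectors and is dropped.
    """
    eta = C @ x + d
    return float(np.sum(y * eta - np.exp(eta)))


rng = np.random.default_rng(1)
C = rng.normal(size=(10, 2))             # loading matrix: 10 neurons, 2 latent dimensions
d = rng.normal(size=10) - 3.0            # bias term, giving a low firing probability per bin
x = rng.normal(size=2)                   # latent state
y = (rng.random(10) < 0.1).astype(float)

A = np.array([[2.0, 0.5], [0.0, 1.0]])   # any invertible matrix
ll = spike_loglik(y, x, C, d)
ll_transformed = spike_loglik(y, np.linalg.solve(A, x), C @ A, d)  # C x = (C A)(A^{-1} x)
shift = np.array([0.3, -0.7])
ll_shifted = spike_loglik(y, x + shift, C, d - C @ shift)         # mean of x traded with d
```

All three values coincide, so the data alone cannot pin down the latent coordinate system.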
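We propose a specific parametrization of the state dynamics, with an additive state transition function and a locally linear input interaction, as a special case of Eq. (2):

$$x_t = x_{t-1} + f(x_{t-1}) + B(x_{t-1})\, u_t + \epsilon_t,$$

where $f$ is a radial basis function (RBF) network with centers $c_i$ and corresponding inverse squared kernel widths $\gamma_i$, and $B(x)$ is a state-dependent input gain.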
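A sketch of this parametrization follows. It assumes Gaussian RBFs $\phi_i(x) = \exp(-\tfrac{1}{2}\gamma_i \lVert x - c_i \rVert^2)$, $f(x) = W_g\, \phi(x)$, and an input gain $B(x)$ obtained by reshaping $W_B\, \phi(x)$ into a $d_x \times d_u$ matrix; the factor 1/2, the form of $B(x)$, and all names are hypothetical choices for illustration.

```python
import numpy as np


def rbf_features(x, centers, gamma):
    """phi_i(x) = exp(-0.5 * gamma_i * ||x - c_i||^2) for centers c_i (rows of `centers`)."""
    sq_dist = np.sum((centers - x) ** 2, axis=1)
    return np.exp(-0.5 * gamma * sq_dist)


def transition_mean(x_prev, u_t, W_g, W_B, centers, gamma):
    """Noise-free part of x_t = x_{t-1} + f(x_{t-1}) + B(x_{t-1}) u_t.

    f(x) = W_g @ phi(x); B(x) is W_B @ phi(x) reshaped to (dim_x, dim_u) (assumed form).
    """
    phi = rbf_features(x_prev, centers, gamma)
    f = W_g @ phi
    B = (W_B @ phi).reshape(x_prev.shape[0], u_t.shape[0])
    return x_prev + f + B @ u_t


rng = np.random.default_rng(0)
n_rbf, dim_x, dim_u = 20, 2, 1
centers = rng.uniform(-2.0, 2.0, size=(n_rbf, dim_x))
gamma = np.ones(n_rbf)
W_g = 0.1 * rng.normal(size=(dim_x, n_rbf))
W_B = 0.1 * rng.normal(size=(dim_x * dim_u, n_rbf))
x_next = transition_mean(np.array([0.3, -0.4]), np.array([1.0]), W_g, W_B, centers, gamma)
```

Evaluating this mean takes on the order of `n_rbf * dim_x * (dim_u + 2)` multiply-adds, a cost that does not depend on the time index.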
The time complexity of our algorithm per time step depends only on the dimensions of the observation, latent space, and input, and on the numbers of hidden units and radial basis functions in this parametrization; it does not grow with time. In contrast, if an efficient offline algorithm such as PLDS [Macke2011c] is rerun for every new observation ("online mode"), its per-step time complexity at time $t$ grows as time passes, making it impractical for real-time applications.
3.2 Phase portrait analysis
The function $f$ in this parametrization directly represents the velocity field of the underlying smooth dynamics in Eq. (1) in the absence of input [Roweis2001; Zhao2016d]. We visualize the phase portrait of the estimated dynamics, which consists of the vector field, example trajectories, and estimated dynamical features (namely, fixed points). We numerically identify candidate fixed points that satisfy $f(x) = 0$. For the simulation studies, we apply an affine transformation to orient the phase portrait to match the canonical equations in the main text.
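One way to carry out the numerical fixed-point search is multi-start Newton iteration on $f(x) = 0$ with a finite-difference Jacobian, merging duplicate roots. This is a hedged sketch of the general technique, not necessarily the root finder used here; the tolerances and the toy vector field in the usage example are assumptions.

```python
import numpy as np


def find_fixed_points(f, starts, tol=1e-8, max_iter=50, h=1e-6, merge_tol=1e-4):
    """Multi-start Newton iterations for f(x) = 0; returns the unique converged roots as rows."""
    roots = []
    for x in np.atleast_2d(starts).astype(float):
        for _ in range(max_iter):
            fx = f(x)
            J = np.empty((x.size, x.size))
            for j in range(x.size):
                e = np.zeros_like(x)
                e[j] = h
                J[:, j] = (f(x + e) - fx) / h  # forward-difference column j of the Jacobian
            try:
                step = np.linalg.solve(J, -fx)
            except np.linalg.LinAlgError:
                break  # singular Jacobian: abandon this starting point
            x = x + step
            if np.linalg.norm(step) < tol:
                break
        is_root = np.linalg.norm(f(x)) < 1e-6
        if is_root and not any(np.linalg.norm(x - r) < merge_tol for r in roots):
            roots.append(x)
    return np.array(roots)


# Toy velocity field with fixed points at (-1, 0), (0, 0), and (1, 0).
def f_toy(x):
    return np.array([x[0] - x[0] ** 3, -x[1]])


roots = find_fixed_points(f_toy, [[-2.0, 0.5], [-0.3, 0.1], [0.4, -0.2], [1.7, 1.0]])
```

With a learned $f$, the starting points would typically be spread over the region visited by the posterior means, and each root's stability can then be read off from the Jacobian.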
4 Simulated Experiments
For the purposes of visualization, we chose to simulate two-dimensional dynamical systems, and generated spikes from a population of simulated neurons via point processes driven by the state dynamics. The average firing rates of the simulated spike trains were kept at around 20 to 40 Hz.
Many models have been proposed in theoretical neuroscience to represent different modes of computation. We apply the proposed method to two such low-dimensional models: the ring attractor model, as a model of internal head direction representation, and a nonlinear oscillator, as a model of rhythmic population-wide activity. We state their conventional formulations in different coordinate systems, but our simulations and inferences are all done in Cartesian coordinates.
The approximate posterior distribution is defined recursively in Eq. (7) as a diagonal Gaussian whose mean and variance are determined from the corresponding observation, input, and previous step by a recurrent neural network. We use a one-hidden-layer MLP in this study. Typically, the state noise variance is unknown and has to be estimated from data. To be consistent with Eq. (10e), we set the starting value of to be , and hence . We initialize the loading matrix $C$ by factor analysis and normalize it at every iteration to keep the system identifiable.
4.1 Ring attractor
We study the following two-variable ring attractor system
where $\theta$ represents the direction driven by the input $u$, and $r$ is the radial component representing the overall activity in the bump. We simulate 100 trajectories of 1000 steps each, with a fixed step size and additive Gaussian noise. We use a strong input (tangential drift) to keep the trajectories flowing around the ring, clockwise or counter-clockwise. We use 20 radial basis functions for the dynamics model and 100 hidden units for the recognition model.
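Simulating a polar-coordinate model in Cartesian coordinates can be illustrated as follows. As a hedged example only, the sketch assumes the minimal form $\dot r = r_0 - r$, $\dot\theta = u$ with a hypothetical $r_0 = 1$ (not necessarily the exact system or parameters used here), converts it to a Cartesian vector field, and integrates it with Euler-Maruyama steps.

```python
import numpy as np


def ring_velocity(z, u, r0=1.0):
    """Cartesian form of dr/dt = r0 - r, dtheta/dt = u (assumed ring attractor)."""
    z1, z2 = z
    r = np.hypot(z1, z2) + 1e-12  # guard against division by zero at the origin
    radial = (r0 - r) / r
    # Radial term moves along z / r; the input term rotates z (counter-clockwise for u > 0).
    return np.array([radial * z1 - u * z2, radial * z2 + u * z1])


def simulate_ring(z0, u, n_steps=1000, dt=0.01, noise_std=0.0, r0=1.0, seed=0):
    """Euler-Maruyama simulation in Cartesian coordinates; returns shape (n_steps + 1, 2)."""
    rng = np.random.default_rng(seed)
    z = np.empty((n_steps + 1, 2))
    z[0] = z0
    for t in range(n_steps):
        drift = ring_velocity(z[t], u, r0)
        z[t + 1] = z[t] + dt * drift + noise_std * np.sqrt(dt) * rng.standard_normal(2)
    return z
```

Without noise, trajectories are attracted to the circle of radius $r_0$ (up to a small discretization offset of order `dt`) while rotating at angular speed `u`.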
Figure 1(a) illustrates one trajectory (black) and its posterior mean (blue). These two trajectories start at the green circle and diamond, respectively, and end at the red markers. The inference starts near the center (origin), relatively far from the true location, because the initial posterior is set to have zero mean. The final states are very close, which implies that the recognition model works well. Figure 1(c) shows the velocity field reconstructed by the model. We visualize the speed as a colored contour and the direction by red arrows. The velocity points toward the ring attractor, and the speed decreases closer to the ring. The model also identifies a number of fixed points arranged around the ring attractor via numerical root finding. Figure 1(d) shows the distribution of posterior means of all data in the state space. The estimated generative model has higher state noise and also stronger attraction toward the ring attractor. This is also evident if we compare trajectories simulated from the inferred dynamics (Figure 2).
Figure 3 shows the three components of the lower bound $\mathcal{L}$ (reconstruction log-likelihood, dynamics log-likelihood, and entropy) together with the objective itself, demonstrating the convergence of the algorithm. We can see each component reaches a plateau within sec. As the reconstruction and dynamics log-likelihoods increase, the recognition model and dynamics model become more accurate, while the decreasing entropy indicates increasing confidence (inverse posterior variance) in the inferred latent states. The average computation time of a dual estimation step is ms on a machine with an Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz (32 cores) and 126GB RAM (no GPU).
4.2 Nonlinear oscillator
We use a 2-dimensional relaxation oscillator with the following nonlinear state dynamics:
where $v$ is the membrane potential, $w$ is a recovery variable, and $I$ is the magnitude of the stimulus current in models of single-neuron biophysics [Izhikevich2007]. This model has also been used to model a global brain state that fluctuates between two levels of excitability in anesthetized cortex [Curto2009]. We simulate 100 trajectories of 1000 steps with a fixed step size, parameter values in the oscillatory regime, and additive Gaussian noise. In this regime, unlike the ring attractor, the spontaneous dynamics is a periodic oscillation, and the trajectory follows a limit cycle.
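A hedged sketch assuming the standard FitzHugh-Nagumo form $\dot v = v - v^3/3 - w + I$, $\dot w = \varepsilon (v + a - b w)$ with textbook values $a = 0.7$, $b = 0.8$, $\varepsilon = 0.08$, $I = 0.5$; these values are assumptions, not the ones used in this experiment. The code locates the intersection of the two nullclines and checks its stability from the Jacobian.

```python
import numpy as np

A, B, EPS, I_EXT = 0.7, 0.8, 0.08, 0.5  # assumed textbook FitzHugh-Nagumo values


def fhn_velocity(z):
    v, w = z
    return np.array([v - v**3 / 3 - w + I_EXT, EPS * (v + A - B * w)])


def fhn_fixed_point():
    """Intersection of the nullclines w = v - v^3/3 + I and w = (v + a) / b."""
    # v - v^3/3 + I - (v + a)/b = 0  <=>  -(1/3) v^3 + (1 - 1/b) v + (I - a/b) = 0
    roots = np.roots([-1.0 / 3.0, 0.0, 1.0 - 1.0 / B, I_EXT - A / B])
    v = roots[np.abs(roots.imag) < 1e-9].real[0]
    return np.array([v, (v + A) / B])


def fhn_jacobian(z):
    v, _ = z
    return np.array([[1.0 - v**2, -1.0], [EPS, -EPS * B]])


z_star = fhn_fixed_point()
J = fhn_jacobian(z_star)
```

With these values the Jacobian at the intersection has positive trace and positive determinant, so the fixed point is unstable, consistent with the limit-cycle regime described in the text.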
We use 20 radial basis functions for the dynamics model and 100 hidden units for the recognition model. While training the model, we do not feed in the input, and expect the model to learn the oscillator.
We also reconstruct the phase portrait (Fig. 4(c)). The two dashed lines are the theoretical nullclines of the true model, on which the velocity of the corresponding dimension is zero. The reconstructed field shows a low-speed valley overlapping with the nullcline, especially in the right half of the figure. The intersection of the two nullclines is an unstable fixed point, and the identified fixed point is close to this intersection.
We run a long-term prediction using the proposed model. The prediction contains both the latent trajectory and the observations over the prediction horizon, and is obtained by sampling from

$$x_t \sim p(x_t \mid x_{t-1}), \qquad y_t \sim p(y_t \mid x_t),$$

given the estimated parameters, without seeing the data during these steps. We show the prediction in Figure 4(e). The upper row shows the true latent trajectory and the corresponding observations. The light-colored part covers the 1000 steps before prediction, and the solid-colored part covers the prediction period. We only show the observations in the prediction period. The lower row shows the filtered trajectory and the prediction by our proposed method. The prediction begins after step 2000.
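The sampling scheme above is ancestral sampling from the generative model. A minimal sketch, assuming a Gaussian transition with isotropic standard deviation `q_std`, a user-supplied (hypothetical) `transition_mean`, and Poisson spike counts with rate $\exp(C x_t + d)$ per bin; the rotation matrix and loadings in the example are stand-ins, not a fitted model.

```python
import numpy as np


def predict_forward(x_last, n_steps, transition_mean, q_std, C, d, seed=0):
    """Ancestral sampling: x_t ~ N(transition_mean(x_{t-1}), q_std^2 I), y_t ~ Poisson(exp(C x_t + d)).

    No observations are used, so the prediction is driven only by the learned model.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x_last, dtype=float)
    xs, ys = [], []
    for _ in range(n_steps):
        x = transition_mean(x) + q_std * rng.standard_normal(x.shape)
        xs.append(x)
        ys.append(rng.poisson(np.exp(C @ x + d)))
    return np.array(xs), np.array(ys)


# Hypothetical stand-ins for a fitted model: a slow rotation and 5 neurons.
angle = 0.05
R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
C = np.ones((5, 2))
d = np.full(5, -3.0)
xs, ys = predict_forward([1.0, 0.0], 200, lambda x: R @ x, 0.01, C, d)
```

Repeating with different seeds gives a Monte Carlo picture of the predictive distribution; starting from a sample of the last filtered posterior rather than its mean would also propagate the filtering uncertainty. With 1 ms bins and low rates, counts above 1 are rare, but they are not excluded by the Poisson model.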
One popular latent process model for point process observations that can make predictions is the Poisson Linear Dynamical System (PLDS) [Macke2011c], which assumes latent linear dynamics. We compare against PLDS fitted with EM on long-term prediction of both the states and the spike trains (Fig. 4(e)).
We also tried applying the unscented Kalman filter (UKF; results not shown) to the same data for dual estimation. However, the UKF is not flexible enough to accommodate arbitrary likelihoods, e.g., Poisson. Even when we used the true dynamics equation instead of the radial basis function network, it suffered from numerical singularity issues and required fine hyperparameter tuning and an initialization close to the true value.
We proposed an online algorithm for recursive variational Bayesian inference for both system identification and filtering. Thanks to the recognition network, which amortizes the optimization cost over time, our online algorithm can achieve real-time performance. Neural dynamics is highly nonlinear, and neural activity is sparse and non-Gaussian. In closed-loop neurophysiological settings, real-time adaptive algorithms are extremely valuable. It is easy to incorporate a spike history filter [Truccolo2005] and a simple model of electrical stimulation artifacts to make the proposed method more useful in practice. Furthermore, a more complex (or perhaps deep) observation model can be used without rederiving the algorithm. Our method fills a gap for online dual estimation with nonlinear dynamics and non-Gaussian observations, for which there previously was no appropriate tool.
A concise description of collective neural activity can be crucial to understanding cognitive behavior. We hope that this tool enables on-the-fly analysis of high-dimensional neural spike train experiments. Clinically, a nonlinear state space model provides a basis for feedback control as a potential treatment for neurological diseases that arise from pathological dynamical states.
One weakness of the current online algorithm is its slow convergence. Since the RNN recognition model performs approximate Bayesian filtering, we could pretrain it on examples of true Bayesian inference under a pretrained generative model. We leave better recognition model architectures and initialization as future work.
- (1) S. Haykin and J. Principe. Making sense of a complex world [chaotic events modeling]. IEEE Signal Processing Magazine, 15(3):66–81, May 1998.
- (2) M. Breakspear. Dynamic models of large-scale brain activity. Nature Neuroscience, 20(3):340–352, February 2017.
- (3) J. C. Kao, P. Nuyujukian, S. I. Ryu, et al. Single-trial dynamics of motor cortex and their applications to brain-machine interfaces. Nature Communications, 6:7759+, July 2015.
- (4) L. Paninski, Y. Ahmadian, D. G. G. Ferreira, et al. A new look at state-space models for neural data. Journal of computational neuroscience, 29(1-2):107–126, August 2010.
- (5) B. M. Yu, J. P. Cunningham, G. Santhanam, et al. Gaussian-process factor analysis for low-dimensional single-trial analysis of neural population activity. Journal of neurophysiology, 102(1):614–635, July 2009.
- (6) V. Mante, D. Sussillo, K. V. Shenoy, and W. T. Newsome. Context-dependent computation by recurrent dynamics in prefrontal cortex. Nature, 503(7474):78–84, November 2013.
- (7) D. Sussillo and O. Barak. Opening the black box: Low-Dimensional dynamics in High-Dimensional recurrent neural networks. Neural Computation, 25(3):626–649, December 2012.
- (8) Y. Zhao and I. M. Park. Interpretable nonlinear dynamic modeling of neural trajectories. In Advances in Neural Information Processing Systems (NIPS), 2016.
- (9) L. R. Hochberg, M. D. Serruya, G. M. Friehs, et al. Neuronal ensemble control of prosthetic devices by a human with tetraplegia. Nature, 442(7099):164–171, July 2006.
- (10) S. Little and P. Brown. What brain signals are suitable for feedback control of deep brain stimulation in Parkinson’s disease? Annals of the New York Academy of Sciences, 1265(1):9–24, August 2012.
- (11) S. Roweis and Z. Ghahramani. Learning nonlinear dynamical systems using the expectation-maximization algorithm, pages 175–220. John Wiley & Sons, Inc, 2001.
- (12) Y. Ho and R. Lee. A Bayesian approach to problems in stochastic estimation and control. IEEE Transactions on Automatic Control, 9(4):333–339, October 1964.
- (13) S. Särkkä. Bayesian filtering and smoothing. Cambridge University Press, 2013.
- (14) S. S. Haykin. Kalman filtering and neural networks. Wiley, 2001.
- (15) Z. Ghahramani and S. T. Roweis. Learning nonlinear dynamical systems using an EM algorithm. In M. J. Kearns, S. A. Solla, and D. A. Cohn, editors, Advances in Neural Information Processing Systems 11, pages 431–437. MIT Press, 1999.
- (16) H. Valpola and J. Karhunen. An unsupervised ensemble learning method for nonlinear dynamic State-Space models. Neural Computation, 14(11):2647–2692, November 2002.
- (17) E. Archer, I. M. Park, L. Buesing, J. Cunningham, and L. Paninski. Black box variational inference for state space models. ArXiv e-prints, November 2015.
- (18) R. G. Krishnan, U. Shalit, and D. Sontag. Deep Kalman filters, November 2015.
- (19) R. G. Krishnan, U. Shalit, and D. Sontag. Structured inference networks for nonlinear state space models. In Thirty-First AAAI Conference on Artificial Intelligence, February 2017.
- (20) M. Johnson, D. K. Duvenaud, A. Wiltschko, R. P. Adams, and S. R. Datta. Composing graphical models with neural networks for structured representations and fast inference. In D. D. Lee, M. Sugiyama, U. V. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems 29, pages 2946–2954. Curran Associates, Inc., 2016.
- (21) M. Karl, M. Soelch, J. Bayer, and P. van der Smagt. Deep variational Bayes filters: Unsupervised learning of state space models from raw data. In 5th International Conference on Learning Representations, 2017.
- (22) M. Watter, J. Springenberg, J. Boedecker, and M. Riedmiller. Embed to control: A locally linear latent dynamics model for control from raw images. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 2746–2754. Curran Associates, Inc., 2015.
- (23) E. A. Wan and R. Van Der Merwe. The unscented Kalman filter for nonlinear estimation. In Proceedings of the IEEE 2000 Adaptive Systems for Signal Processing, Communications, and Control Symposium (Cat. No.00EX373), pages 153–158. IEEE, August 2000.
- (24) E. A. Wan and A. T. Nelson. Dual extended Kalman filter methods, pages 123–173. John Wiley & Sons, Inc, 2001.
- (25) G. E. Hinton, P. Dayan, B. J. Frey, and R. M. Neal. The "wake-sleep" algorithm for unsupervised neural networks. Science, 268(5214):1158–1161, May 1995.
- (26) D. Sussillo, R. Jozefowicz, L. F. Abbott, and C. Pandarinath. LFADS - latent factor analysis via dynamical systems, August 2016.
- (27) D. J. Rezende, S. Mohamed, and D. Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, May 2014.
- (28) D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In International Conference on Learning Representations, May 2014.
- (29) M. Abadi, A. Agarwal, P. Barham, et al. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. Software available from tensorflow.org.
- (30) J. P. Newman, M.-f. Fong, D. C. Millard, et al. Optogenetic feedback control of neural activity. eLife.
- (31) A. El Hady. Closed Loop Neuroscience. Academic Press, 2016.
- (32) Y. Zhao and I. M. Park. Variational latent Gaussian process for recovering single-trial dynamics from population spike trains. Neural Computation, 29(5), May 2017.
- (33) Y. Zhao, J. Yates, and I. M. Park. Low-dimensional state-space trajectory of choice at the population level in area MT. In Computational and Systems Neuroscience (COSYNE), 2017.
- (34) E. M. Izhikevich. Dynamical systems in neuroscience: the geometry of excitability and bursting. Computational neuroscience. MIT Press, 2007.
- (35) C. Curto, S. Sakata, S. Marguet, V. Itskov, and K. D. Harris. A simple model of cortical dynamics explains variability and state dependence of sensory responses in Urethane-Anesthetized auditory cortex. The Journal of Neuroscience, 29(34):10600–10612, August 2009.
- (36) J. H. Macke, L. Buesing, J. P. Cunningham, et al. Empirical models of spiking in neural populations. In J. Shawe-Taylor, R. S. Zemel, P. L. Bartlett, F. Pereira, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems 24, pages 1350–1358. Curran Associates, Inc., 2011.
- (37) W. Truccolo, U. T. Eden, M. R. Fellows, J. P. Donoghue, and E. N. Brown. A point process framework for relating neural spiking activity to spiking history, neural ensemble, and extrinsic covariate effects. Journal of Neurophysiology, 93(2):1074–1089, February 2005.