Collapsed amortized variational inference for SNLDS



We propose an efficient inference method for switching nonlinear dynamical systems. The key idea is to learn an inference network which can be used as a proposal distribution for the continuous latent variables, while performing exact marginalization of the discrete latent variables. This allows us to use the reparameterization trick, and apply end-to-end training with stochastic gradient descent. We show that the proposed method can successfully segment time series data, including videos and 3D human pose, into meaningful “regimes” by using the piece-wise nonlinear dynamics.


1 Introduction

Consider looking down on an airplane flying across country or a car driving through a field. The vehicle’s motion is composed of straight, linear dynamics and curving, nonlinear dynamics. An example is illustrated in fig. 1(a). In this paper, we propose a new inference algorithm for fitting switching nonlinear dynamical systems (SNLDS), which can be used to segment time series of high-dimensional signals, such as videos, or lower dimensional signals, such as (x,y) locations, into meaningful discrete temporal “modes” or “regimes”. The transitions between these modes may correspond to changes in the internal goals of the agent (e.g., a mouse switching from running to resting, as in Johnson et al. (2016)) or may be caused by external factors (e.g., changes in the road curvature). Discovering such discrete modes is useful for scientific applications (cf. Wiltschko et al. (2015); Linderman et al. (2019); Sharma et al. (2018)) as well as for planning in the context of hierarchical reinforcement learning (cf. Kipf et al. (2019)).

Figure 1: (a): Trajectory and ground truth segmentation of a particle. The direction of motion is indicated by the arrows. Blue is moving straight, yellow is turning counter-clockwise, red is turning clockwise. (c) Segmentation learned by our SNLDS model. (d) Segmentation learned by a SLDS model. Note that to model the nonlinear dynamics, the SLDS model needs to use more segments.

Extensive previous work, some of which we review in Section 2, explores modeling temporal data using various forms of state space models (SSM). We are interested in the class of SSM which has both discrete and continuous latent variables, which we denote by $s_t$ and $z_t$, where $t$ is the discrete time index. The discrete state, $s_t$, represents the mode of the system at time $t$, and the continuous state, $z_t$, represents other factors of variation, such as location and velocity. The observed data is denoted by $x_t$, and can either be a low dimensional projection of $z_t$, such as the current location, or a high dimensional signal that is informative about $z_t$, such as an image. We may optionally have observed input or control signals $u_t$, which drive the system in addition to unobserved stochastic noise. We are interested in learning a generative model of the form $p_\theta(s_{1:T}, z_{1:T}, x_{1:T} \mid u_{1:T})$ from partial observations, namely $(x_{1:T}, u_{1:T})$. This requires inferring the posterior over the latent states, $p_\theta(s_{1:T}, z_{1:T} \mid v_{1:T})$, where $v_t = (x_t, u_t)$ contains all the visible variables at time $t$. For training purposes, we usually assume that we have multiple such trajectories, possibly of different lengths, but we omit the sequence indices from our notation for simplicity. This problem is very challenging, because the model contains both discrete and continuous latent variables (a so-called “hybrid system”) and has nonlinear transition and observation models.

Figure 2: Left: Illustration of the generative model. Dashed arrows indicate optional connections. Right: Illustration of the inference network. Solid black arrows share parameters with the generative model, solid blue arrows have parameters that are unique to $q_\phi$. The diamonds represent deterministic nodes computed with RNNs: $h_t^x$ is a bidirectional RNN applied to $x_{1:T}$, and $h_t^z$ is a unidirectional RNN applied to $h_t^x$ and $z_{t-1}$.

The main contribution of our paper is a new way to perform efficient approximate inference in this class of SNLDS models. The key observation is that, conditioned on knowing $z_{1:T}$ as well as $v_{1:T}$, we can marginalize out $s_{1:T}$ in linear time using the forward-backward algorithm. In particular, we can efficiently compute the gradient of the log marginal likelihood, $\nabla \log \sum_{s_{1:T}} p_\theta(\tilde{z}_{1:T}, s_{1:T}, v_{1:T})$, where $\tilde{z}_{1:T}$ is a posterior sample that we need for model fitting. To efficiently compute posterior samples $\tilde{z}_{1:T}$, we learn an amortized inference network $q_\phi(z_{1:T} \mid v_{1:T})$ for the “collapsed” NLDS model $p_\theta(z_{1:T}, v_{1:T})$. Collapsing removes the discrete variables, and allows us to use reparameterization for the continuous $z$. These tricks let us use stochastic gradient descent (SGD) to learn $p_\theta$ and $q_\phi$ jointly, as explained in Section 3. We can then use $q_\phi$ as a proposal distribution inside a Rao-Blackwellised particle filter (Doucet et al., 2000), although in this paper, we just use a single posterior sample, as is common with Variational AutoEncoders (VAEs, Kingma and Welling (2014); Rezende et al. (2014)).

Although the above “trick” allows us to efficiently perform inference and learning, we find that in challenging problems (e.g., when the dynamical model is very flexible), the model uses only a single discrete latent variable and does not perform mode switching. This is a form of “posterior collapse”, similar to VAEs, where powerful decoders can cause the latent variables to be ignored, as explained in Alemi et al. (2018). Our second contribution is a new form of posterior regularization, which prevents the aforementioned problem and results in a significantly improved segmentation.

We apply our method, as well as various existing methods, to two previously proposed low-dimensional time series segmentation problems, namely a 1d bouncing ball, and a 2d moving arm. In the 1d case, the dynamics are piecewise linear, and all methods perform perfectly. In the 2d case, the dynamics are piecewise nonlinear, and we show that our method infers much better segmentation than previous approaches for comparable computational cost. We also apply our method to a simple new video dataset (see fig. 1 for an example) and sequences of human poses, and find that it performs well, provided we use our proposed regularization method.

In summary, our main contributions are

  • Learning switching nonlinear dynamical systems parameterized with neural networks by marginalizing out discrete variables.

  • Using entropy regularization and annealing to encourage discrete state transitions.

  • Demonstrating that the discrete states of nonlinear models are more interpretable.

2 Related Work

2.1 State space models

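We consider the following state space model:

$$p_\theta(x_{1:T}, z_{1:T}, s_{1:T}) = p(x_1 \mid z_1)\, p(z_1 \mid s_1)\, p(s_1) \prod_{t=2}^{T} p(x_t \mid z_t)\, p(z_t \mid z_{t-1}, s_t)\, p(s_t \mid s_{t-1}, x_{t-1}), \tag{1}$$

where $s_t$ is the discrete hidden state, $z_t$ is the continuous hidden state, and $x_t$ is the observed output, as in fig. 2(a). For notational simplicity, we ignore any observed inputs or control signals $u_t$, but these can be trivially added to our model.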

Note that the discrete state influences the latent dynamics $z_t$, but we could trivially make it influence the observations $x_t$ as well. More interesting is the question of which edges we choose to add as parents of the discrete state $s_t$. We consider the case where $s_t$ depends on the previous discrete state, $s_{t-1}$, as in a hidden Markov model (HMM), but also depends on the previous observation, $x_{t-1}$. This means that state changes do not have to happen “open loop”, but instead may be triggered by signals from the environment. We can trivially depend on multiple previous observations; we assume first-order Markov for simplicity. We can also condition $z_t$ on $x_{t-1}$, and $s_t$ on $z_{t-1}$. It is straightforward to handle such additional dependencies (shown by dashed lines in fig. 2(a)) in our inference method, which is not true for some of the other methods we discuss below.

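We still need to specify the functional forms of the conditional probability distributions. In this paper, we make the following fairly weak assumptions:

$$p(x_t \mid z_t) = \mathcal{N}(x_t \mid f_x(z_t), R), \tag{2}$$
$$p(z_t \mid z_{t-1}, s_t = k) = \mathcal{N}(z_t \mid f_z(z_{t-1}, k), Q), \tag{3}$$
$$p(s_t \mid s_{t-1} = j, x_{t-1}) = \mathrm{Cat}(s_t \mid \mathcal{S}(f_s(x_{t-1}, j))), \tag{4}$$

where $f_x$, $f_z$, and $f_s$ are nonlinear functions (MLPs or RNNs), $\mathcal{N}$ is a multivariate Gaussian distribution, $\mathrm{Cat}$ is a categorical distribution, and $\mathcal{S}$ is a softmax function. $R$ and $Q$ are learned covariance matrices for the Gaussian emission and transition noise.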
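To make the generative process concrete, the following is a minimal sketch of ancestral sampling from a toy one-dimensional model with the structure described above. The specific functions `f_x`, `f_z`, `f_s`, the noise scales, and the uniform initial state distribution are hypothetical choices made only for illustration; they are not the networks used in the experiments.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_snlds(T, K=2, seed=0):
    """Draw one trajectory (s, z, x) from a toy 1-d switching nonlinear model."""
    rng = random.Random(seed)

    # Hypothetical nonlinear components (stand-ins for MLPs or RNNs).
    def f_z(z_prev, k):          # mean of p(z_t | z_{t-1}, s_t = k)
        return math.tanh(z_prev) + (0.5 if k == 0 else -0.5)

    def f_x(z):                  # mean of p(x_t | z_t)
        return 2.0 * z

    def f_s(x_prev, j):          # logits of p(s_t | s_{t-1} = j, x_{t-1})
        return [3.0 * (i == j) + (x_prev if i == 0 else -x_prev)
                for i in range(K)]

    q_std, r_std = 0.1, 0.05     # assumed transition / emission noise scales

    s, z, x = [], [], []
    for t in range(T):
        if t == 0:
            s_t = rng.randrange(K)           # assumed uniform p(s_1)
            z_t = rng.gauss(0.0, 1.0)        # assumed state-independent p(z_1 | s_1)
        else:
            probs = softmax(f_s(x[-1], s[-1]))
            s_t = rng.choices(range(K), weights=probs)[0]
            z_t = rng.gauss(f_z(z[-1], s_t), q_std)
        x_t = rng.gauss(f_x(z_t), r_std)
        s.append(s_t)
        z.append(z_t)
        x.append(x_t)
    return s, z, x
```

Because the discrete transition logits depend on the previous observation, switches in this sketch are triggered by the signal itself rather than occurring “open loop”.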

If $f_x$ and $f_z$ are both linear, and $p(s_t \mid s_{t-1})$ is first-order Markov without dependence on $z_{t-1}$, the model is called a switching linear dynamical system (SLDS). If we allow $s_t$ to depend on $z_{t-1}$, the model is called a recurrent SLDS (Linderman et al., 2017; Linderman and Johnson, 2017). We will compare to rSLDS in our experiments.

If $f_z$ is linear, but $f_x$ is nonlinear, the model is sometimes called a “structured variational autoencoder” (SVAE) (Johnson et al., 2016), although that term is ambiguous, since there are many forms of structure. We will compare to SVAEs in our experiments.

If $f_z$ is a linear function, the model may need to use many discrete states in order to approximate the nonlinear dynamics, as illustrated in fig. 1(d). We therefore allow $f_z$ (and $f_x$) to be nonlinear. The resulting model is called a switching nonlinear dynamical system (SNLDS), or Nonlinear Regime-Switching State-Space Model (RSSSM) (Chow and Zhang, 2013). Prior work typically assumes $f_z$ is a simple nonlinear model, such as polynomial regression. If we let $f_z$ be a very flexible neural network, there is a risk that the model will not need to use the discrete states at all. We discuss a solution to this in Section 3.3.

The discrete dynamics can be modeled as a semi-Markov process, where states have explicit durations (see e.g., Duong et al. (2005); Chiappa (2014)). One recurrent, variational version is the recurrent hidden semi-Markov model (rHSMM, Dai et al. (2017)). Rather than having a stochastic continuous variable at every timestep, rHSMM instead stochastically switches between states with deterministic dynamics. The semi-Markovian structures in this work have an explicit maximum duration, which makes them less flexible. A revised method, CompILE (Kipf et al., 2019), is able to better handle unknown durations, but produces a potentially infinite number of distinct states, each with deterministic dynamics. The deterministic dynamics of these works may limit their ability to handle noise.

2.2 Variational inference and learning

A common approach to learning latent variable models is to maximize the evidence lower bound (ELBO) on the log marginal likelihood (see e.g., Blei et al. (2016)). This is given by $\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(z, s \mid x)}\left[\log p_\theta(x, z, s) - \log q_\phi(z, s \mid x)\right]$, where $q_\phi(z, s \mid x)$ is an approximate posterior. Rather than computing $q$ using optimization for each $x$, we can train an inference network, $f_\phi(x)$, which emits the parameters of $q$. This is known as “amortized inference” (see e.g., Kingma and Welling (2014)).

If the posterior distribution $q_\phi(z, s \mid x)$ is reparameterizable, then we can make the noise independent of $\phi$, and hence apply standard SGD to optimize $\theta$ and $\phi$. Unfortunately, the discrete distribution $q_\phi(s \mid x)$ is not reparameterizable. In such cases, we can either resort to higher variance methods for estimating the gradient, such as REINFORCE, or we can use continuous relaxations of the discrete variables, such as Gumbel-Softmax (Jang et al., 2017), Concrete (Maddison et al., 2017b), or a combination of both, such as REBAR (Tucker et al., 2017). We will compare against a Gumbel-Softmax version of SNLDS in our experiments. The continuous relaxation approach was applied to SLDS models in Becker-Ehmck et al. (2019) and to HSMM models in Liu et al. (2018a) and Kipf et al. (2019). However, the relaxation can lose many of the benefits of having discrete variables (Le et al., 2019). Relaxing the distribution to a soft mixture of dynamics results in the Kalman VAE (KVAE) model of Fraccaro et al. (2017). We will compare to KVAE in our experiments. A concern is that soft models may use a mixture of dynamics for distinct ground truth states rather than assigning a distinct mode of dynamics at each step as a discrete model must do. In Section 3, we propose a new method to avoid these issues, in which we collapse out $s$ so that the entire model is differentiable.

The SVAE model of Johnson et al. (2016) also uses the forward-backward algorithm to compute $q(s \mid v)$; however, they assume the dynamics of $z$ are linear Gaussian, so they can apply the Kalman smoother to compute $q(z \mid v)$. Assuming linear dynamics can result in over-segmentation, as we have discussed. A forward-backward algorithm is applied once to the discrete states and once to the continuous states to compute a structured mean field posterior $q(z)\,q(s)$. In contrast, we perform approximate inference for $z$ using one forward-backward pass of a non-linear network and then exact inference for $s$ using a second pass, as we explain in Section 3.

2.3 Monte Carlo inference

There is a large literature on using sequential Monte Carlo methods for inference in state space models as particle filters (see e.g., Doucet and Johansen (2011)). When the model is nonlinear (as in our case), we may need many particles to get a good approximation, which can be expensive. We can often get better (lower variance) approximations by analytically marginalizing out some of the latent variables; the resulting method is called a “Rao-Blackwellised particle filter” (RBPF).

Prior work (e.g., Doucet et al. (2001)) has applied RBPF to SLDS models, leveraging the fact that it is possible to marginalize out $z$ using the Kalman filter. It is also possible to compute the optimal proposal distribution for sampling from $s$ in this case. However, this relies on the model being conditionally linear Gaussian. In contrast, we marginalize out $s$, so we can handle nonlinear models. In this case, it is hard to compute the optimal proposal distribution for sampling from $z$, so instead we use variational inference to learn to approximate this.

3 Method

3.1 Inference

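We use the following variational posterior: $q_{\phi,\theta}(z, s \mid x) = q_\phi(z \mid x)\, p_\theta(s \mid z, x)$, where $p_\theta(s \mid z, x)$ is the exact posterior (under the generative model) computed using the forward-backward algorithm, and $q_\phi(z \mid x)$ is defined below. To compute $q_\phi(z \mid x)$, we first process $x_{1:T}$ through a bidirectional RNN, whose state at time $t$ is denoted by $h_t^x$. We then use a forward (causal) RNN, whose state is denoted by $h_t^z$, to compute the parameters of $q_\phi(z_t \mid z_{1:t-1}, x_{1:T})$, where the hidden state $h_t^z$ is computed based on $h_t^x$ and $z_{t-1}$. This gives the following approximate posterior:

$$q_\phi(z_{1:T} \mid x_{1:T}) = \prod_{t=1}^{T} q_\phi(z_t \mid z_{1:t-1}, x_{1:T}) = \prod_{t=1}^{T} q_\phi(z_t \mid h_t^z).$$

See fig. 2(b) for an illustration.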

We can draw a sample $\tilde{z}_{1:T} \sim q_\phi(z_{1:T} \mid x_{1:T})$ sequentially, and then treat this as “soft evidence” for the HMM model. We can use a forward-backward algorithm to integrate out the discrete variables and compute gradients as in Eq. 8. This approach offers a great amount of modeling flexibility. The only constraints are that $q_\phi$ is differentiable and that the discrete variables can be integrated out of $p_\theta$ to also make it differentiable. The continuous transition dynamics can be linear, a simple non-linear kernel function, or a complicated function parameterized as an artificial neural network or RNN. The discrete transitions can depend on observed data, control signals, or the soft evidence samples, $\tilde{z}$. The flexibility of this formulation allows it to cover the model families of multiple prior works (Johnson et al., 2016; Linderman et al., 2017; Chow and Zhang, 2013; Doucet et al., 2000) with a single core algorithm.
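The marginalization step can be sketched as a scaled forward-backward pass over per-state likelihood terms evaluated at the sampled continuous states. This is a minimal illustration in plain Python, not the authors' implementation: the argument names `pi`, `A`, and `B` are assumptions, with `B[t][k]` standing for the per-state likelihood of the sample at step `t` and `A[t][j][k]` for the (possibly time-varying) transition probability into step `t + 1`. A brute-force enumeration is included only to check the result on tiny problems.

```python
import math
from itertools import product


def forward_backward(pi, A, B):
    """Return (log marginal likelihood, gamma, xi) for a discrete chain.

    pi[k]      : initial state probabilities
    A[t][j][k] : transition probability from state j at step t to k at step t+1
    B[t][k]    : likelihood of the sampled evidence at step t under state k
    """
    T, K = len(B), len(pi)
    alpha, c = [], []
    for t in range(T):
        if t == 0:
            a = [pi[k] * B[0][k] for k in range(K)]
        else:
            a = [B[t][k] * sum(alpha[t - 1][j] * A[t - 1][j][k] for j in range(K))
                 for k in range(K)]
        norm = sum(a)                       # c[t] = p(evidence_t | evidence_{<t})
        c.append(norm)
        alpha.append([v / norm for v in a])

    beta = [[1.0] * K for _ in range(T)]
    for t in range(T - 2, -1, -1):
        beta[t] = [sum(A[t][j][k] * B[t + 1][k] * beta[t + 1][k] for k in range(K))
                   / c[t + 1] for j in range(K)]

    # Single-slice and two-slice posterior marginals.
    gamma = [[alpha[t][k] * beta[t][k] for k in range(K)] for t in range(T)]
    xi = [[[alpha[t][j] * A[t][j][k] * B[t + 1][k] * beta[t + 1][k] / c[t + 1]
            for k in range(K)] for j in range(K)] for t in range(T - 1)]
    return sum(math.log(v) for v in c), gamma, xi


def brute_force_log_marginal(pi, A, B):
    """Sum over all K**T state paths; only feasible for tiny problems."""
    T, K = len(B), len(pi)
    total = 0.0
    for path in product(range(K), repeat=T):
        p = pi[path[0]] * B[0][path[0]]
        for t in range(1, T):
            p *= A[t - 1][path[t - 1]][path[t]] * B[t][path[t]]
        total += p
    return math.log(total)
```

The forward-backward pass costs time linear in the sequence length, whereas the enumeration is exponential; the two agree on small inputs, which is a convenient sanity check.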

3.2 Learning

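The evidence lower bound (ELBO) for a single sequence is given by

$$\mathcal{L}_{\mathrm{ELBO}}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x)\, p_\theta(s \mid x, z)}\left[\log p_\theta(x, z, s) - \log q_\phi(z \mid x)\, p_\theta(s \mid x, z)\right] = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z \mid x)\right].$$

Because $q_\phi(z \mid x)$ is reparameterizable, we can approximate the gradient as follows:

$$\nabla_{\theta, \phi} \mathcal{L}_{\mathrm{ELBO}}(\theta, \phi) \approx \nabla_{\theta, \phi} \log p_\theta(x, \tilde{z}) - \nabla_\phi \log q_\phi(\tilde{z} \mid x),$$

where $\tilde{z}_{1:T}$ is a sample from the variational proposal $q_\phi(z_{1:T} \mid x_{1:T})$. The second term can be computed by applying backpropagation through time to the inference RNN. In the appendix, we show that the first term is given by

$$\nabla \log p_\theta(x, \tilde{z}) = \sum_{k} \gamma_1(k)\, \nabla \log\left[B_1(k)\, \pi(k)\right] + \sum_{t=2}^{T} \sum_{j, k} \xi_t(j, k)\, \nabla \log\left[B_t(k)\, A_t(j, k)\right], \tag{8}$$

where $\pi(k) = p(s_1 = k)$, $A_t(j, k) = p(s_t = k \mid s_{t-1} = j, x_{t-1})$, $B_t(k) = p(x_t \mid \tilde{z}_t)\, p(\tilde{z}_t \mid \tilde{z}_{t-1}, s_t = k)$ (with $B_1(k) = p(x_1 \mid \tilde{z}_1)\, p(\tilde{z}_1 \mid s_1 = k)$), and $\gamma_t(k) = p(s_t = k \mid x, \tilde{z})$ and $\xi_t(j, k) = p(s_{t-1} = j, s_t = k \mid x, \tilde{z})$ are the single-slice and two-slice posterior marginals computed by the forward-backward algorithm.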

3.3 Entropy regularization and temperature annealing

When using expressive nonlinear functions (e.g. an RNN or MLP) to model $p(z_t \mid z_{t-1}, s_t)$, we found that the model only used a single discrete state, analogous to posterior collapse in VAEs (see e.g., Alemi et al. (2018)). The forward-backward algorithm causes this behavior because low-probability states are never improved. Prior work, such as Linderman et al. (2017), solves this problem by multi-step pretraining to ensure the model is well initialized. To encourage the model to utilize multiple states, we add an additional regularizing term to the ELBO that penalizes the KL divergence between the state posterior at each time step and a uniform prior (Burke et al., 2019). We call this a cross-entropy regularizer:

$$\mathcal{L}_{\mathrm{CE}} = \sum_{t=1}^{T} \mathrm{KL}\left(p_{\mathrm{prior}}(s_t) \,\|\, p_\theta(s_t \mid z_{1:T}, x_{1:T})\right).$$

Our overall objective now becomes

$$\mathcal{L}(\theta, \phi) = \mathcal{L}_{\mathrm{ELBO}}(\theta, \phi) - \beta\, \mathcal{L}_{\mathrm{CE}},$$

where $\beta$ is a scaling factor. To further smooth the optimization problem, we apply temperature annealing to the discrete state transitions, as follows: $p(s_t = k \mid s_{t-1} = j, x_{t-1}) = \mathcal{S}\left(f_s(x_{t-1}, j) / \tau\right)_k$, where $\tau$ is the temperature.

At the beginning stage of training, $\beta$ and $\tau$ are set to large values. Doing so ensures that all states are visited, and can explain the data well. Over time, we reduce the regularizer to $\beta = 0$ and the temperature to $\tau = 1$, according to a fixed annealing schedule. Initially, the regularization induces correlated dynamics because each state needs to be used, but annealing allows the dynamics to decorrelate (see Section A.6 and cf. Rose (1998)). The result is similar to multi-step pretraining but our approach works in a continuous end-to-end fashion.
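As an illustration of the two smoothing devices described in this section, the sketch below computes a temperature-scaled softmax over transition logits and a regularizer that sums, over time steps, the KL divergence from a uniform prior to the posterior marginals. The direction of the KL term, the function names, and the way the penalty is combined with the ELBO are assumptions made for this example.

```python
import math


def tempered_softmax(logits, tau):
    """Softmax of logits / tau; large tau flattens the distribution."""
    scaled = [v / tau for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_regularizer(gamma):
    """Sum over t of KL(uniform || gamma[t]).

    gamma[t][k] is the posterior marginal of state k at step t and is
    assumed to be strictly positive.
    """
    K = len(gamma[0])
    u = 1.0 / K
    return sum(u * (math.log(u) - math.log(g_k)) for g in gamma for g_k in g)


def regularized_objective(elbo, gamma, beta):
    """ELBO minus the scaled penalty (to be maximized)."""
    return elbo - beta * cross_entropy_regularizer(gamma)
```

The penalty is zero when every state is equally likely at every step and grows as the posterior concentrates on a single state, so a large coefficient pushes the model to keep all states in use early in training.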

4 Experiments

In this section, we compare our method to various other methods that have been recently proposed for time series segmentation using latent variable models. Since it is hard to evaluate segmentation without labels, we use three synthetic datasets, where we know the ground truth, for quantitative evaluation but we also qualitatively evaluate the segmentation on a real world dataset.

In each case, we fit the model to the data, and then estimate the most likely hidden, discrete state at each time step, $\hat{s}_t = \arg\max_k p(s_t = k \mid x_{1:T})$. Since the model is unidentifiable, the state labels have no meaning, so we post-process them by selecting the permutation over labels that maximizes the $F_1$ score across frames. The $F_1$ score is the harmonic mean of precision and recall, $F_1 = 2PR / (P + R)$, where the precision $P$ is the percentage of the predictions that match the ground truth states, and the recall $R$ is the percentage of the ground truth states that match the predictions. We also compute the switching-point $F_1$ by only considering the frames where the ground truth state changes. This measure complements the frame-wise $F_1$, because it measures temporal specificity.
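The label-matching step can be sketched as an exhaustive search over permutations of the predicted labels. The example below scores each permutation with a per-state $F_1$ averaged over states; this macro-averaging is an assumption of the sketch, since the text does not specify how per-state scores are aggregated, and the function names are hypothetical.

```python
from itertools import permutations


def macro_f1(truth, pred, num_states):
    """Average of per-state F1 scores over states that appear at least once."""
    scores = []
    for k in range(num_states):
        tp = sum(1 for t, p in zip(truth, pred) if t == k and p == k)
        n_pred = sum(1 for p in pred if p == k)
        n_true = sum(1 for t in truth if t == k)
        if n_pred == 0 and n_true == 0:
            continue                      # state unused in both sequences
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true if n_true else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(scores) / len(scores)


def best_permutation_f1(truth, pred, num_states):
    """Return (best score, permutation) where permutation[old] = new label."""
    best_score, best_perm = -1.0, None
    for perm in permutations(range(num_states)):
        relabeled = [perm[p] for p in pred]
        score = macro_f1(truth, relabeled, num_states)
        if score > best_score:
            best_score, best_perm = score, perm
    return best_score, best_perm
```

Exhaustive search is only practical for a small number of states, since the number of permutations grows factorially.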

4.1 1d bouncing ball

DATASET                  Bouncing Ball               Reacher Task
METRIC                   F1 (S.P.)    F1 (F.W.)      F1 (S.P.)     F1 (F.W.)
SLDS (Ours)              100.         100.           59.6 ± 3.2    81.0 ± 3.4
rSLDS                    100.         100.
SVAE                     100.         100.
KVAE                     100.         100.
SNLDS (Ours)             100.         100.           78.1 ± 4.2    89.0 ± 2.0
Gumbel-Softmax SNLDS
CompILE                  -            -                            -
Table 1: Quantitative comparisons (in %) for segmentation on bouncing ball and reacher task. We report the $F_1$ scores in percentage with mean and standard deviation over 5 runs. (S.P. for switching point, F.W. for frame-wise, the best mean is in bold.) The $F_1$ score for CompILE is adapted from Kipf et al. (2019), where only switching point $F_1$ score is provided. The $F_1$ score for KVAE is computed based on taking ‘argmax’ on the ‘dynamics parameter network’ as described in Fraccaro et al. (2017).
Figure 3: SNLDS Segmentation on bouncing ball task with an RNN continuous transition function. Top: illustration of input sequence and reconstruction. Center (green): ground truth of the latent discrete states that correspond to the two directions of motion. Center (blue): the posterior marginals of of SNLDS at , , and training steps, where lighter colors represent higher likelihood. Bottom: Training progress of the log relative negative log-likelihood (Orange) and frame-wise F1 score (Blue) for SNLDS. Log relative negative log-likelihood is calculated as , where is negative log-likelihood. The scale emphasizes that the loss still improves even late during training.

In this section, we use a simple dataset from Johnson et al. (2016). The data encodes the location of a ball bouncing between two walls in a one dimensional space. The initial position and velocity are random, but the wall locations are constant.

We apply our SNLDS model to this data, where $f_x$ and $f_z$ are both MLPs. We found that regularization was not necessary in this experiment. We also consider the case where $f_x$ and $f_z$ are linear (i.e. an SLDS model), the rSLDS model of Linderman et al. (2017), the SVAE model of Johnson et al. (2016), the Kalman VAE (KVAE) model of Fraccaro et al. (2017) and a Gumbel-Softmax version of SNLDS as described in Section A.2. We use the implementations of rSLDS, SVAE, and KVAE provided by the authors.

All models we tested learn a perfect segmentation, as shown in Figure 4(a) and Table 1. This serves as a “sanity check” that we are able to use and implement the rSLDS, SVAE, KVAE and Gumbel-Softmax SNLDS code correctly. (See also Section A.3 for further analysis.)

Note that the “true” number of discrete states is just 2, encoding whether the ball is moving up or down. We find that our method can learn to ignore irrelevant discrete states if they are not needed. This is presumably because we are maximizing the marginal likelihood since we sum over all hidden states, and this is known to encourage model simplicity due to the “Bayesian Occam’s razor” effect (Murray and Ghahramani, 2005). By contrast, we had to be more careful in setting the number of discrete states $K$ when using the other methods.

An example of training an SNLDS model on the Bouncing Ball task is shown in Figure 3. Early in training, the discrete states do not align well to the ground truth transitions. The three states transition rapidly near one of the walls and the frame-wise F1 score is near chance values. However, by ten thousand iterations, the model has learned to ignore one state and switches between the two states corresponding to the ball bouncing from the wall. Notably, the negative log-likelihood changes by over 10 orders of magnitude before the model learns accurate segmentation of even this simple problem. We hypothesize that the likelihood is dominated by errors in continuous dynamics rather than in the discrete segmentation until very late in training.

4.2 2d reacher task

In this section, we consider a dataset proposed in the CompILE paper (Kipf et al., 2019). The observations are sequences of dimensional vectors, derived from the 2d locations of various static objects, and the 2d joint locations of a moving arm (see Section A.4 for details and a visualization). The ground truth discrete state for this task is the identity of the target that the arm is currently reaching for (i.e., its “goal”).

We fit the same 6 models as above to this dataset. It is a much harder problem that requires more expressive dynamics, and we found that we needed to add regularization to our model to encourage it to switch states. Figure 4(b) visualizes the resulting segmentation (after label permutation) for a single example. We see that our SNLDS model matches the ground truth more closely than our SLDS model, as well as the rSLDS, SVAE, KVAE, and Gumbel-Softmax baselines.

To compare performance quantitatively, we evaluate the models from different training runs on the same held-out dataset of size , and compute the $F_1$ scores. We also report the number from CompILE. The CompILE paper uses an iterative segmentation scheme that can detect state changes, but it does not infer what the current latent state is, so we cannot include it in Figure 4(b). In Table 1, we find that our SNLDS method is significantly better than the other approaches.

Figure 4: Segmentation on bouncing ball (top) and reacher task (bottom). From top to bottom: ground truth of latent discrete states, then the posterior marginals, $p(s_t = k \mid x_{1:T})$, of the SNLDS, SLDS, rSLDS, SVAE, KVAE, and Gumbel-Softmax SNLDS models respectively, where lighter color represents higher probability. CompILE is not included because it represents a different model family that directly predicts the segment boundary without calculating posterior marginals at each time step.

4.3 Dubins path

(Switching point, Tol 0) 11.3 5.7
(Switching point, Tol 5) 82.5 1.9
(Frame-wise) 84.3 7.2
Table 2: Quantitative comparisons (in %) for S(N)LDS on Dubins path. For SLDS, $F_1$ scores with both greedy 1-to-1 matching (Greedy) and optimal merging (Merging) are provided. The switching point $F_1$ scores are estimated with either precise matching (Tol 0) or allowing at most 5-step displacement (Tol 5).

In this section, we apply our method to a new dataset that is created by rendering a point moving in the 2d plane. The motion follows the Dubins model, a simple model for piece-wise nonlinear (but smooth) motion that is commonly used in the fields of robotics and control theory because it corresponds to the shortest path between two points that can be traversed by wheeled robots, airplanes, etc. In the Dubins model, the change in direction is determined by an external control signal $u_t$. We replace this with three latent discrete control states: go straight, turn left, and turn right. These correspond to fixed, but unobserved, input signals (see Section A.5 for details). After generating the motion, we create a series of images, where we render the location of the moving object as a small circle on a white background. Our goal in generating this dataset was to assess how well we can recover latent dynamics from image data in a very simple, yet somewhat realistic, setting.
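A trajectory of this kind can be simulated in a few lines. The sketch below is only an illustration of the idea of replacing the control signal with three discrete states: the turn rates, the constant speed, the switching probability, and the function name are hypothetical, and the rendering of positions into images is omitted.

```python
import math
import random

# Hypothetical heading change per step for each discrete control state.
TURN_RATE = {0: 0.0, 1: 0.2, 2: -0.2}   # 0: straight, 1: turn left, 2: turn right


def simulate_dubins(T, speed=1.0, switch_prob=0.05, seed=0):
    """Return (states, points) for a constant-speed path with switching turns."""
    rng = random.Random(seed)
    x, y, theta = 0.0, 0.0, 0.0
    state = 0
    states, points = [], []
    for _ in range(T):
        # With small probability, jump to a different control state.
        if rng.random() < switch_prob:
            state = rng.choice([k for k in TURN_RATE if k != state])
        theta += TURN_RATE[state]
        x += speed * math.cos(theta)
        y += speed * math.sin(theta)
        states.append(state)
        points.append((x, y))
    return states, points
```

Within each regime the motion is a straight line or a circular arc, so a single nonlinear regime suffices where a piecewise linear model would need several.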

The publicly released code for rSLDS and SVAE does not support high dimensional inputs like images (even though the SVAE has been applied to an image dataset in Johnson et al. (2016)), and there is no public code for CompILE. Therefore we could not compare to these methods for this experiment. As we already showed in Section 4.2 that our method is much better than these other approaches, as well as Kalman VAE and Gumbel-Softmax version of SNLDS, on other tasks, we expect the same conclusion to hold on the harder task of segmenting videos.

Instead we focus on comparing inferred SNLDS states with SLDS states to determine the advantage of allowing each regime to be represented by a nonlinear model. The results of segmenting one sequence with these models using states are shown in Figure 1. We see that the SLDS model has to approximate the left and right turns with multiple discrete states, whereas the non-linear model learns a more interpretable representation.

We again compare the models in Table 2 using $F_1$ scores. Since matching the exact time of the switching point is very hard in the unsupervised setting with noisy observations, we also report an $F_1$ computed with a tolerance of detecting a change within 5 frames. Because the SLDS model used too many states, we calculated two versions of the metrics. The first was a greedy approach that optimally assigned the best single state to match each ground truth state. The second used an oracle to optimally merge states to match the ground truth. The SNLDS model significantly outperforms the SLDS in both scenarios.
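One way to score switching points with a tolerance is sketched below. The matching rule (each predicted switch may claim the nearest unmatched ground-truth switch within `tol` frames) is an assumption of this example rather than the exact evaluation protocol, and the function names are hypothetical.

```python
def switch_points(states):
    """Indices t at which the state differs from the previous frame."""
    return [t for t in range(1, len(states)) if states[t] != states[t - 1]]


def switching_point_f1(truth, pred, tol=0):
    """F1 over switch times, counting a prediction as correct within tol frames."""
    true_sw, pred_sw = switch_points(truth), switch_points(pred)
    if not true_sw and not pred_sw:
        return 1.0
    unmatched = list(true_sw)
    tp = 0
    for p in pred_sw:
        hits = [t for t in unmatched if abs(t - p) <= tol]
        if hits:
            # Each ground-truth switch can be matched at most once.
            unmatched.remove(min(hits, key=lambda t: abs(t - p)))
            tp += 1
    precision = tp / len(pred_sw) if pred_sw else 0.0
    recall = tp / len(true_sw) if true_sw else 0.0
    denom = precision + recall
    return 2 * precision * recall / denom if denom else 0.0
```

For instance, a predicted switch two frames late scores 0 with `tol=0` but counts as correct with `tol=5`.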

4.4 Salsa Dancing from CMU MoCap

In this section, we demonstrate the capacity of SNLDS on segmenting 3D human pose dynamics on CMU MoCap data. There are trials of Salsa dancing sequences in the dataset. We use of them as the training data, and hold out the other for evaluation. The training sequences are generated by down-sampling the original sequences, keeping every 6th frame. The input to the model consists of coordinates of joints. Using an MLP to describe the nonlinear transition of continuous hidden states, SNLDS can segment sequences into modes of primitive motions, which could be interpreted as: turning clockwise, turning counter-clockwise, and translational motion. Without ground truth segmentation, we only evaluate the segmentation qualitatively, as shown in Figure 5.

Figure 5: SNLDS segmentation result for Salsa dancing trial in CMU MoCap dataset. The model segments the motion into three different dynamical modes: moving forward and backward (orange colored), clockwise turning (magenta) and counter-clockwise turning (green). The center depicts the posterior marginal for each state and the boxes show samples of motion from each state.

4.5 Analysis of the annealing schedule

Many latent variable models are trained in multiple stages to avoid getting stuck in bad local optima. For example, to fit the rSLDS model, Linderman et al. (2017) first pretrain an AR-HMM and SLDS model, and then merge them; similarly, to fit the SVAE model, Johnson et al. (2016) first train with a single latent state and then increase the number of states $K$.

We found a similar strategy was necessary for the Reacher, Dubins, and Salsa tasks, but we do this in a smooth way using annealed regularization. Early in training, we train with large temperature $\tau$ and entropy coefficient $\beta$. This encourages the model to use all states equally, so that the dynamics, inference, and emission sub-networks stabilize before beginning to learn specialized behavior. We then anneal the entropy coefficient to $\beta = 0$, and the temperature to $\tau = 1$ over time. We found it best to first decay the entropy coefficient $\beta$ and then decay the temperature $\tau$.
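A delayed annealing schedule of this kind can be sketched as a value that is held constant until a start step and then decays exponentially toward its final value. The exponential form, all numeric constants, and the final values (0 for the entropy coefficient and 1 for the temperature, chosen so that the unregularized objective is recovered) are assumptions made for illustration only.

```python
def annealed_value(step, start_value, end_value, start_step, decay_rate, decay_every):
    """Hold start_value until start_step, then decay exponentially to end_value."""
    if step <= start_step:
        return start_value
    n = (step - start_step) / decay_every
    return end_value + (start_value - end_value) * decay_rate ** n


def schedule(step):
    """Return (beta, tau) at a training step; all constants are placeholders."""
    # The entropy coefficient starts decaying first ...
    beta = annealed_value(step, 1000.0, 0.0,
                          start_step=1000, decay_rate=0.5, decay_every=100)
    # ... and the temperature only starts decaying later.
    tau = annealed_value(step, 10.0, 1.0,
                         start_step=2000, decay_rate=0.5, decay_every=100)
    return beta, tau
```

Staggering the two start steps mirrors the ordering described above, in which the entropy coefficient is decayed before the temperature.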

Figure 6: Comparing the relative negative log-likelihood (top) and the frame-wise scores (bottom) on Dubins paths with 3 different annealing schedules. In the first run (green), the regularization coefficient and temperature start to decay at the very beginning of training. In the second run (red), the cross entropy regularization coefficient starts to decay at step , while temperature annealing starts at step . In the third run (blue), the coefficient decay starts at step , while temperature annealing starts at step .

Figure 6 demonstrates the effect of different annealing schedules on the relative log likelihood (defined as , where across all three runs, and is the negative log-likelihood.), and the $F_1$ score. We find that the final negative log-likelihood and $F_1$ scores improve when we delay the annealing schedule to steps on the Dubins task. Surprisingly, the $F_1$ score does not improve significantly until an additional steps after the temperature begins annealing. On real problems, where we have no ground truth, we cannot use the $F_1$ score as a metric to determine the best annealing schedule. However, it seems that the schedules that improve $F_1$ the most also improve likelihood the most.

5 Conclusion

We have demonstrated that our proposed method can effectively learn to segment high dimensional sequences into meaningful discrete regimes. Future work includes applying this to harder image sequences and to hierarchical reinforcement learning.

Appendix A

A.1 Derivation of the gradient of the ELBO

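The evidence lower bound objective (ELBO) of the model is defined as:

$$\mathcal{L}_{\mathrm{ELBO}}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z)\right] + H\left(q_\phi(z \mid x)\right), \tag{11}$$

where the first term is the model likelihood, and the second is the conditional entropy for the variational posterior of the continuous hidden states. We can approximate the entropy of $q_\phi(z \mid x)$ as:

$$H\left(q_\phi(z \mid x)\right) \approx \sum_{t=1}^{T} H\left(q_\phi(z_t \mid \tilde{z}_{1:t-1}, x_{1:T})\right),$$

where $\tilde{z}$ is a sample from the variational posterior. In other words, we compute the marginal entropy for the output of the RNN inference network at each time step, and then sample a single latent vector to update the RNN state for the next step.

In order to apply stochastic gradient descent for end-to-end training, the minibatch gradient for the first term in the ELBO (Eq. 11) with respect to $\theta$ is estimated as

$$\nabla_\theta\, \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z)\right] \approx \nabla_\theta \log p_\theta(x, \tilde{z}).$$

For the gradient with respect to $\phi$, we can use the reparameterization trick to write

$$\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z)\right] = \mathbb{E}_{\epsilon}\left[\log p_\theta(x, z_\phi(\epsilon, x))\right].$$

Therefore, the gradient is expressed as:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z)\right] = \mathbb{E}_{\epsilon}\left[\nabla_\phi \log p_\theta(x, z_\phi(\epsilon, x))\right] \approx \nabla_\phi \log p_\theta(x, \tilde{z}_\phi(\epsilon, x)).$$

It remains to compute the derivative of the log-joint likelihood, $\nabla \log p(v, z)$, where we write $v$ for the visible variables (equal to $x_{1:T}$ when there are no control inputs), $z$ for $\tilde{z}_{1:T}$, and $s$ for $s_{1:T}$ for brevity. We have

$$\nabla \log p(v, z) = \mathbb{E}_{p(s \mid v, z)}\left[\nabla \log p(v, z, s)\right] - \mathbb{E}_{p(s \mid v, z)}\left[\nabla \log p(s \mid v, z)\right] = \mathbb{E}_{p(s \mid v, z)}\left[\nabla \log p(v, z, s)\right],$$

where we used the fact that $\log p(v, z) = \log p(v, z, s) - \log p(s \mid v, z)$ and

$$\mathbb{E}_{p(s \mid v, z)}\left[\nabla \log p(s \mid v, z)\right] = \sum_{s} \nabla p(s \mid v, z) = \nabla 1 = 0.$$

For $\log p(v, z, s)$, we use the Markov property to rewrite it as:

$$\log p(v, z, s) = \log\left[B_1(s_1)\, \pi(s_1)\right] + \sum_{t=2}^{T} \log\left[B_t(s_t)\, A_t(s_{t-1}, s_t)\right],$$

where $\pi(k) = p(s_1 = k)$, $A_t(j, k) = p(s_t = k \mid s_{t-1} = j, x_{t-1})$, and $B_t(k) = p(x_t \mid z_t)\, p(z_t \mid z_{t-1}, s_t = k)$ (with $B_1(k) = p(x_1 \mid z_1)\, p(z_1 \mid s_1 = k)$), with the expectation being:

$$\mathbb{E}_{p(s \mid v, z)}\left[\nabla \log p(v, z, s)\right] = \sum_{k} \gamma_1(k)\, \nabla \log\left[B_1(k)\, \pi(k)\right] + \sum_{t=2}^{T} \sum_{j, k} \xi_t(j, k)\, \nabla \log\left[B_t(k)\, A_t(j, k)\right],$$

where $\gamma_t(k) = p(s_t = k \mid v, z)$ and $\xi_t(j, k) = p(s_{t-1} = j, s_t = k \mid v, z)$ are the posterior marginals computed by the forward-backward algorithm.

Therefore we reach Eq. 8.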
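The identity underlying this derivation, namely that the gradient of the log marginal equals the posterior expectation of the gradient of the complete-data log likelihood, can be checked numerically on a tiny chain. The sketch below is an assumption-laden toy: it uses two states, fixed per-state likelihood terms `B`, a fixed initial distribution `PI`, and time-homogeneous transition logits `W`, and it enumerates all state paths instead of running forward-backward, so it is suitable only for very short sequences.

```python
import math
from itertools import product

PI = [0.6, 0.4]                                   # toy initial distribution
B = [[0.5, 0.1], [0.3, 0.4], [0.05, 0.6]]         # toy per-state likelihood terms


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def joint(path, W):
    """p(evidence, s = path) with transition matrix A = row-wise softmax(W)."""
    A = [softmax(r) for r in W]
    p = PI[path[0]] * B[0][path[0]]
    for t in range(1, len(B)):
        p *= A[path[t - 1]][path[t]] * B[t][path[t]]
    return p


def log_marginal(W):
    return math.log(sum(joint(p, W) for p in product(range(2), repeat=len(B))))


def expected_gradient(W, j, m):
    """Posterior expectation of d/dW[j][m] of the complete-data log likelihood."""
    A = [softmax(r) for r in W]
    paths = list(product(range(2), repeat=len(B)))
    Z = sum(joint(p, W) for p in paths)
    grad = 0.0
    for p in paths:
        # d log A[j][k] / dW[j][m] = 1[k == m] - A[j][m], summed over the
        # transitions of this path that leave state j.
        g = sum((p[t] == m) - A[j][m]
                for t in range(1, len(B)) if p[t - 1] == j)
        grad += joint(p, W) / Z * g
    return grad


def finite_difference(W, j, m, eps=1e-6):
    """Central difference of the log marginal with respect to W[j][m]."""
    Wp = [r[:] for r in W]
    Wm = [r[:] for r in W]
    Wp[j][m] += eps
    Wm[j][m] -= eps
    return (log_marginal(Wp) - log_marginal(Wm)) / (2 * eps)
```

In practice the expectation over paths is replaced by sums over the single-slice and two-slice marginals, which gives the same value at linear cost.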

In summary, one step of stochastic gradient ascent for the ELBO can be implemented as Algorithm 1.

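1:  Compute $h_{1:T}^x$ from $v_{1:T}$ using a Bi-RNN;
2:  Recursively sample $\tilde{z}_t \sim q_\phi(z_t \mid \tilde{z}_{1:t-1}, v_{1:T})$ using RNN over $h_t^x$ and $\tilde{z}_{t-1}$;
3:  Run forward-backward messages to compute the transition terms $A$, the likelihood terms $B$, the initial distribution $\pi$, and the posterior marginals $\gamma$, $\xi$ from $(\tilde{z}_{1:T}, v_{1:T})$;
4:  Compute $\nabla_{\theta, \phi} \mathcal{L}$ from Eqn. 8;
5:  Take gradient step.
Algorithm 1 SVI for Training SNLDS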

A.2 Gumbel-Softmax SNLDS

Instead of marginalizing out the discrete states with the forward-backward algorithm, one could use a continuous relaxation via reparameterization, e.g. the Gumbel-Softmax trick (Jang et al., 2017), to infer the most likely discrete states. We call this Gumbel-Softmax SNLDS.

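We consider the same state space model as SNLDS:

$$p_\theta(x_{1:T}, z_{1:T}, s_{1:T}) = p(x_1 \mid z_1)\, p(z_1 \mid s_1)\, p(s_1) \prod_{t=2}^{T} p(x_t \mid z_t)\, p(z_t \mid z_{t-1}, s_t)\, p(s_t \mid s_{t-1}, x_{t-1}),$$

where $s_t$ is the discrete hidden state, $z_t$ is the continuous hidden state, and $x_t$ is the observed output, as in Figure 2(a). The inference network for the variational posterior now predicts both $s_{1:T}$ and $z_{1:T}$ and is defined as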



where $h_t^z$ is the hidden state of a deterministic recurrent neural network, which works from left ($t = 1$) to right ($t = T$), summarizing past stochastic $z_{1:t-1}$. We also feed in $h_t^x$, the state of a bidirectional RNN, which summarizes $x_{1:T}$. The Gumbel-Softmax distribution takes the output of a feed-forward network and a softmax temperature $\tau$, which is annealed according to a fixed schedule.

The evidence lower bound (ELBO) can be written as


One step of stochastic gradient ascent for the ELBO can be implemented as Algorithm 2.

1:  Use Bi-RNN to compute from ;
2:  Recursively sample using RNN over and ;
3:  Recursively sample with distribution , where is a feedforward network;
4:  Compute the likelihood for Eq. 13;
5:  Take gradient step.
Algorithm 2 SVI for Training Gumbel-Softmax SNLDS

A.3 Details on the bouncing ball experiment

The input data for the bouncing ball experiment is a set of sample trajectories, each of timesteps, with the initial position placed randomly between two walls separated by a distance of . The velocity of the ball for each sample trajectory is sampled from . The exact position of the ball is obscured with Gaussian noise . Training is performed with batch size . Evaluation is carried out on a fixed, held-out subset of the data with samples. For the inference network, the bidirectional and forward RNNs are both dimensional GRUs. The dimensions of the discrete and continuous hidden states are set to and . For SLDS, we use a linear transition for the continuous states. For SNLDS, we use a GRU with hidden units followed by a linear transformation for the continuous state transition. The model is trained with a fixed learning rate of , using the Adam optimizer (Kingma and Ba, 2015) and gradient clipping by norm of , for steps.
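A data generator of this kind can be sketched as follows. All names and arguments (`bouncing_ball`, `wall_distance`, `velocity`, `noise_std`) are hypothetical, the velocity is passed in rather than sampled, and no numeric setting from the experiment is reproduced here.

```python
import random


def bouncing_ball(num_steps, wall_distance, velocity, noise_std, seed=0):
    """Simulate a 1-D ball bouncing between walls at 0 and wall_distance.

    Returns (true_positions, observations); observations add Gaussian noise.
    """
    rng = random.Random(seed)
    position = rng.uniform(0.0, wall_distance)  # random start between the walls
    true_positions, observations = [], []
    for _ in range(num_steps):
        position += velocity
        # Reflect off a wall (repeatedly, in case of a large step).
        while position < 0.0 or position > wall_distance:
            if position < 0.0:
                position, velocity = -position, -velocity
            else:
                position, velocity = 2.0 * wall_distance - position, -velocity
        true_positions.append(position)
        observations.append(position + rng.gauss(0.0, noise_std))
    return true_positions, observations
```

The two regimes a segmentation model would be expected to recover correspond to the sign of the velocity.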

A.4 Details on the reacher experiment

Figure 7: Illustration of the observations in the reacher experiment. This is a -D rendering of the observation vector, but the inputs to the model are sequences of vectors, as in Kipf et al. (2019), not images.

The observations in the reacher experiment are sequences of dimensional vectors, as described in Kipf et al. (2019). The first elements are the target indicator, , and location, , for 10 randomly generated objects. out of objects start as targets, . The locations of 5 of the non-target objects are set to . A deterministic controller moves the arm to the indicated target objects. Once a target is reached, its indicator is set to (depicted as the yellow dot disappearing in Figure 7). The remaining elements of the observations are the two angles of the reacher arm and the positions of the two arm segment tips. The training dataset consists of observation samples, each timesteps in length.

This more complex task requires more careful training. The learning rate schedule is a linear warm-up, to over steps, from followed by a cosine decay, with a decay rate of and a minimum of . The entropy regularization coefficient starts to decay exponentially after steps, from an initial value of with a decay rate of and decay steps of . The temperature annealing follows the same exponential schedule but only starts to decay after steps. Training is performed in minibatches of size for iterations using the Adam optimizer (Kingma and Ba, 2015).
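The two kinds of schedule described above can be sketched as plain functions of the training step. The names `learning_rate` and `delayed_exponential_decay`, their arguments, and the exact cosine parameterization are assumptions; the experiment's actual constants are not reproduced here.

```python
import math


def learning_rate(step, peak_lr, warmup_steps, decay_steps, min_lr, init_lr=0.0):
    """Linear warm-up from init_lr to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    # Fraction of the cosine phase completed, clipped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine


def delayed_exponential_decay(step, initial_value, decay_rate, decay_steps, start_step):
    """Hold initial_value until start_step, then decay exponentially.

    Intended for either the entropy coefficient or the temperature, with a
    different start_step for each.
    """
    if step < start_step:
        return initial_value
    return initial_value * decay_rate ** ((step - start_step) / decay_steps)
```

Using one decay function with two different `start_step` values mirrors the description that the temperature follows the same exponential schedule but begins decaying later.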

The model architecture is relatively generic. The continuous hidden state is dimensional. The number of discrete hidden states is set to for training, which is larger than the ground truth (including states targeting objects and a finished state). The observations pass through an encoding network with two -unit ReLU-activated fully-connected nets before feeding into the RNN inference networks, which estimate the posterior distributions . The RNN inference networks consist of a -unit bidirectional LSTM and a -unit forward LSTM. The emission network is a three-layer MLP with hidden units and ReLU activation for the first two layers and a linear output layer. The discrete hidden state transition network takes two inputs: the previous discrete state and the processed observations. The observations are processed by the encoding network and a -D convolution with kernels of size . The transition network outputs a matrix of transition probabilities at each timestep. For SNLDS, we use a single-layer MLP as the continuous hidden state transition functions , with hidden units and ReLU activation. For SLDS, we use linear transitions for the continuous state.

A.5 Details on the Dubins path experiment

The Dubins path model is a simplified flight, or vehicle, trajectory that is the shortest path to reach a target position, given the initial position , the direction of motion , the speed constant , and the maximum curvature constraint . The possible motion along the path is defined by

The path type can be described by three different modes/regimes: ‘right turn (R)’, ‘left turn (L)’, or ‘straight (S)’.

To generate a sample trajectory used in training or testing, we randomly sample the velocity from a uniform distribution (pixel/step), the angular frequency from a uniform distribution (/step), and the initial direction . The generated trajectories always start from the center of the image . The duration of each regime is sampled from a Poisson distribution with mean steps, with a full sequence length of steps. The floating-point positional information is rendered onto a image with Gaussian blurring with standard deviation to minimize aliasing.
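The sampling procedure can be sketched as below, before any image rendering. Everything here is an illustrative assumption: the names `dubins_trajectory` and `sample_poisson`, the uniform choice among the three regimes, the uniformly sampled initial heading, the sign convention for left and right turns, and all argument values. The motion update is the standard constant-speed unicycle step (heading changes by a fixed angular rate, position advances along the heading).

```python
import math
import random


def sample_poisson(mean, rng):
    """Poisson sample via Knuth's multiplication method (fine for small means)."""
    limit, k, p = math.exp(-mean), 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1


def dubins_trajectory(num_steps, speed, omega, mean_duration, start=(0.0, 0.0), seed=0):
    """Generate positions and regime labels ('L', 'S', 'R') for one trajectory."""
    rng = random.Random(seed)
    turn_rate = {"L": omega, "S": 0.0, "R": -omega}  # radians per step
    x, y = start
    theta = rng.uniform(0.0, 2.0 * math.pi)  # assumed: uniform initial heading
    positions, regimes = [], []
    while len(positions) < num_steps:
        regime = rng.choice("LSR")  # assumed: regimes drawn uniformly
        duration = max(1, sample_poisson(mean_duration, rng))
        for _ in range(duration):
            if len(positions) == num_steps:
                break
            theta += turn_rate[regime]
            x += speed * math.cos(theta)
            y += speed * math.sin(theta)
            positions.append((x, y))
            regimes.append(regime)
    return positions, regimes
```

The returned regime labels play the role of the ground-truth segmentation against which a learned segmentation could be compared.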

The same schedules as in the reacher experiment are used for the learning rate, temperature annealing and regularization coefficient decay.

The network architecture is similar to that of the reacher task except for the encoder and decoder networks. Each observation is encoded with a CoordConv (Liu et al., 2018b) network before passing into the RNN inference networks; the architecture is defined in Table 3. The emission network also uses a CoordConv network, as described in Table 4. The continuous hidden state in this experiment is dimensional. The number of discrete hidden states is set to , which is larger than the ground truth . The inference networks are a -unit bidirectional LSTM and a -unit forward LSTM. The discrete hidden state transition network takes the output of the observation encoding network in the same manner as in the reacher task. For SNLDS, we use a two-layer MLP as the continuous hidden state transition function , with hidden units and ReLU activation. For SLDS, we use a linear transition for the continuous states.

Layer | Filters | Shape | Activation | Stride | Padding
1 | 2 | [5, 5] | relu | 1 | same
2 | 4 | [5, 5] | relu | 2 | same
3 | 4 | [5, 5] | relu | 1 | same
4 | 8 | [5, 5] | relu | 2 | same
5 | 8 | [7, 7] | relu | 1 | valid
6 | 8 | 2 (kernel size) | None | 1 | causal
Table 3: CoordConv encoder architecture. Before passing into the following network, the image is padded from to with the pixel coordinates.

Layer | Filters | Shape | Activation | Stride | Padding
1 | 14 | [1, 1] | relu | 1 | valid
2 | 14 | [1, 1] | relu | 1 | valid
3 | 28 | [1, 1] | relu | 1 | valid
4 | 28 | [1, 1] | relu | 1 | valid
5 | 1 | [1, 1] | relu | 1 | same
Table 4: CoordConv decoder architecture. Before passing into the following network, the input is tiled from to , where is the hidden dimension, and is then padded to with the pixel coordinates.
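The coordinate padding mentioned in both captions can be illustrated with a small sketch that appends two coordinate channels to an image. The function name `add_coord_channels`, the nested-list image layout, and the normalization of coordinates to [-1, 1] are assumptions; the source only states that pixel coordinates are appended.

```python
def add_coord_channels(image):
    """Append normalized (row, column) coordinate channels to an image.

    image: nested lists of shape H x W x C.
    Returns nested lists of shape H x W x (C + 2).
    """
    height, width = len(image), len(image[0])

    def normalize(index, size):
        # Map 0..size-1 to [-1, 1]; a single-pixel axis maps to 0.
        return 0.0 if size == 1 else 2.0 * index / (size - 1) - 1.0

    return [
        [
            list(image[r][c]) + [normalize(r, height), normalize(c, width)]
            for c in range(width)
        ]
        for r in range(height)
    ]
```

Because every pixel then carries its own position, the 1x1 convolutions in a decoder of this kind can produce spatially varying output from a tiled latent vector.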

See Figure 8 for an illustration of the reconstruction abilities (of the observed images) for the SLDS and SNLDS models. They are visually very similar; however, the SNLDS has a more interpretable latent state as described in Section 4.3.

Figure 8: Image sequence reconstruction for Dubins path. The sequence is averaged with early timepoints scaled to low intensity, late timepoints unchanged to indicate direction.

A.6 Regularization and Multi-step Training

Figure 9: Comparing the average Pearson correlations among the weights from individual dynamical transition modes, , trained on Dubins Paths. Run (green) is trained without regularization. Run (blue) has its entropy coefficient starting to exponentially decay at step , and the temperature starting to anneal at step .

When our SNLDS model is trained with a powerful transition network but without regularization, it fits the dynamics with a single state. With randomly initialized networks, one state fits the dynamics better at the beginning, and the forward-backward algorithm causes more gradient to flow through that state than through the others. As a result, the best state is the only one that improves.

To prevent this, we use regularization to cause the model to select each mode with uniform likelihood until the inference and emission network are well trained. Thus all discrete modes are able to learn the dynamics well initially. When the regularization decays, the transition dynamics of each mode can then specialize. One effect of this regularization strategy is that the weights for each dynamics module are correlated early during training and decorrelate when the regularization decays. The regularization helps the model to better utilize its capacity, and the model can achieve better likelihood, as demonstrated in Section 4.5 and Figure 6.
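The correlation diagnostic described here can be sketched as an average of pairwise Pearson correlations between flattened weight vectors, one per dynamics mode. The function names and the flattening into plain lists are assumptions; how the weights are extracted from the transition networks is not specified in the text.

```python
import math


def pearson(a, b):
    """Pearson correlation of two equal-length sequences (non-constant inputs)."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def average_pairwise_correlation(weight_vectors):
    """Mean Pearson correlation over all distinct pairs of weight vectors."""
    n = len(weight_vectors)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    return sum(pearson(weight_vectors[i], weight_vectors[j]) for i, j in pairs) / len(pairs)
```

Tracked over training steps, a value near 1 would indicate that the modes still share nearly identical dynamics, and a drop would indicate specialization.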

Multi-step training has been used by previous models, and it serves the same purpose as our regularization. SVAE first trains a single transition model, then uses that one set of parameters to initialize all the transition dynamics for multiple states in the next stage of training. rSLDS training begins by fitting a single AR-HMM for initialization, then fits a standard SLDS, before finally fitting the rSLDS model. We follow these implementations of both SVAE and rSLDS in our paper. Both multi-step training and our regularization ensure that the hidden dynamics are well learned before the segmentation is learned. What makes our regularization approach interesting is that it allows the model to be trained with a smooth transition between early and late training.


  1. In the case of sequential models, we can create tighter lower bounds using methods such as FIVO (Maddison et al., 2017a), although this is orthogonal to our work.


  1. Fixing a broken ELBO. In Intl. Conf. on Machine Learning (ICML). Cited by: §1, §3.3.
  2. Switching linear dynamics for variational Bayes filtering. In Intl. Conf. on Machine Learning (ICML). Cited by: §2.2.
  3. Variational inference: a review for statisticians. J. of the Am. Stat. Assoc. (JASA). Cited by: §2.2.
  4. Hybrid system identification using switching density networks. CoRR abs/1907.04360. Cited by: §3.3.
  5. Explicit-duration Markov switching models. Foundations and Trends in Machine Learning 7 (6). Cited by: §2.1.
  6. Nonlinear regime-switching state-space (RSSS) models. Psychometrika 78 (4), pp. 740–768. Cited by: §2.1, §3.1.
  7. Recurrent hidden semi-Markov model. In Intl. Conf. on Learning Representations (ICLR). Cited by: §2.1.
  8. Rao-Blackwellised particle filtering for dynamic Bayesian networks. In Proc. of the Conf. on Uncertainty in AI (UAI). Cited by: §1, §3.1.
  9. Particle filters for state estimation of jump Markov linear systems. IEEE Trans. on Signal Processing 49 (3), pp. 613–624. Cited by: §2.3.
  10. A tutorial on particle filtering and smoothing: fifteen years later. In The Oxford Handbook of Nonlinear Filtering, D. Crisan and B. Rozovsk (Eds.). Cited by: §2.3.
  11. Activity recognition and abnormality detection with the switching hidden semi-Markov model. In Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR). Cited by: §2.1.
  12. A disentangled recognition and nonlinear dynamics model for unsupervised learning. In Advances in Neural Info. Proc. Systems (NIPS). Cited by: §2.2, §4.1, Table 1.
  13. Categorical reparameterization with Gumbel-Softmax. In Intl. Conf. on Learning Representations (ICLR). Cited by: §A.2, §2.2.
  14. Structured VAEs: composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Info. Proc. Systems (NIPS). Cited by: §1, §2.1, §2.2, §3.1, §4.1, §4.1, §4.3, §4.5.
  15. Adam: A method for stochastic optimization. In Intl. Conf. on Learning Representations (ICLR). Cited by: §A.3, §A.4.
  16. Auto-encoding variational Bayes. In Intl. Conf. on Learning Representations (ICLR). Cited by: §1, §2.2.
  17. CompILE: compositional imitation learning and execution. In Intl. Conf. on Machine Learning (ICML). Cited by: Figure 7, §A.4, §1, §2.1, §2.2, §4.2, Table 1.
  18. Revisiting reweighted Wake-Sleep for models with stochastic control flow. In Proc. of the Conf. on Uncertainty in AI (UAI). Cited by: §2.2.
  19. Bayesian learning and inference in recurrent switching linear dynamical systems. In Conf. on AI and Statistics (AISTATS). Cited by: §2.1, §3.1, §3.3, §4.1, §4.5.
  20. Hierarchical recurrent state space models reveal discrete and continuous dynamics of neural activity in C. elegans. In bioRxiv. Cited by: §1.
  21. Structure-exploiting variational inference for recurrent switching linear dynamical systems. In IEEE 7th International Workshop on Computational Advances in Multi-Sensor Adaptive Processing (CAMSAP). Cited by: §2.1.
  22. Structured inference for recurrent hidden semi-Markov model. In Intl. Joint Conf. on AI (IJCAI), pp. 2447–2453. Cited by: §2.2.
  23. An intriguing failing of convolutional neural networks and the CoordConv solution. In Advances in Neural Info. Proc. Systems (NeurIPS), Vol. abs/1807.03247. Cited by: §A.5.
  24. Filtering variational objectives. In Advances in Neural Info. Proc. Systems (NIPS). Cited by: footnote 1.
  25. The concrete distribution: A continuous relaxation of discrete random variables. In Intl. Conf. on Learning Representations (ICLR). Cited by: §2.2.
  26. A note on the evidence and Bayesian Occam’s razor. Technical report, Gatsby Computational Neuroscience Unit, University College London. Cited by: §4.1.
  27. Stochastic backpropagation and approximate inference in deep generative models. In Intl. Conf. on Machine Learning (ICML). Cited by: §1.
  28. Deterministic annealing for clustering, compression, classification, regression, and related optimization problems. Proc. IEEE 80, pp. 2210–2239. Cited by: §3.3.
  29. Point process latent variable models of larval zebrafish behavior. In Advances in Neural Info. Proc. Systems (NeurIPS). Cited by: §1.
  30. REBAR: low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Info. Proc. Systems (NIPS). Cited by: §2.2.
  31. Mapping sub-second structure in mouse behavior. Neuron 88 (6), pp. 1121–1135. Cited by: §1.