Consistent Generative Query Networks
Stochastic video prediction is usually framed as an extrapolation problem where the goal is to sample a sequence of consecutive future image frames conditioned on a sequence of observed past frames. For the most part, algorithms for this task generate future video frames sequentially in an autoregressive fashion, which is slow and requires the input and output to be consecutive. We introduce a model that overcomes these drawbacks – it learns to generate a global latent representation from an arbitrary set of frames within a video. This representation can then be used to simultaneously and efficiently sample any number of temporally consistent frames at arbitrary time-points in the video. We apply our model to synthetic video prediction tasks and achieve results that are comparable to state-of-the-art video prediction models. In addition, we demonstrate the flexibility of our model by applying it to 3D scene reconstruction where we condition on location instead of time. To the best of our knowledge, our model is the first to provide flexible and coherent prediction on stochastic video datasets, as well as consistent 3D scene samples. Please check the project website https://bit.ly/2jX7Vyu to view scene reconstructions and videos produced by our model.
The ability to fill in the gaps in high-dimensional data is a fundamental cognitive skill. Suppose you glance out of the window and see a person in uniform approaching your gate carrying a letter. You can easily imagine what will (probably) happen next. The person will walk up the path and push the letter through your door. Now suppose you glance out of the window the following day and see a person in the same uniform walking down the path, away from the house. You can easily imagine what (probably) just happened. The person came through your gate, walked up the path, and delivered a letter. Moreover, in both instances, you can visualize the scene from different viewpoints. From your vantage point at the window, you can imagine how things might look from the gate, or from the front door, or even from your neighbour’s roof. Essentially, you have learned from your past experience of unfolding visual scenes how to extrapolate and interpolate in both time and space.
Replicating this ability is a significant challenge for artificial intelligence. Building on the recently developed Generative Query Network (GQN) framework (Eslami et al., 2018), we here present a neural network architecture that learns models that can flexibly extrapolate in the visual domain. Moreover, as well as models operating in time, our method can learn models that operate in space. To achieve this, we had to overcome a significant difficulty that does not arise with autoregressive models, namely the need to generate consistent sets of samples for a given temporal context even when the problem is non-deterministic. Similarly, when conditioned on camera position, our models can sample consistent sets of images for an occluded region of a scene, even when there are multiple possibilities for what that region might contain.
To make this more precise, let’s first consider video prediction, a task that has been widely studied in machine learning and computer vision. At its core, any video prediction task involves learning a model that can sample a set of future frames conditioned on a sequence of previous frames. State-of-the-art models often carry out the prediction sequentially in an autoregressive manner, sampling each new frame conditioned on all of the frames that precede it. While autoregressive models are able to generate accurate sequences of predicted frames, they are usually restricted with regard to the structure and order of their inputs – the input and output frames must be consecutive. Further, their autoregressive one-frame-at-a-time nature renders them slow to sample from.
Instead, we cast the video prediction problem as a “query” task. Generative Query Networks (GQNs) (Eslami et al., 2018) are spatial prediction models that are “queried” at test time given a set of conditioning input pairs. In its original setting, a trained GQN is given frames from a single 3D scene together with the camera positions from which those frames were rendered. These input pairs are referred to as the context. The model is then asked to sample a plausible frame rendered from an arbitrary camera position. In this work, we apply this framework to temporal as well as spatial data. When we apply it to video prediction, the context corresponds to the pairs of individual frames and the timestamps at which they occur, and the query contains the timestamps of the frames that we want to sample.
In many cases there are multiple possible future frames that are consistent with the context, making the problem stochastic. For example, a car moving towards an intersection could turn left or turn right. Given the context (timestamps and frames) of the car moving to the intersection, and a single query timestamp, GQN can sample a plausible frame at the specified time. However, unlike autoregressive models, each frame is sampled independently. GQN will sometimes sample a frame of the car on the left and sometimes sample a frame of the car on the right, but cannot capture a coherent sequence of frames where the car turns left. To address these issues, we introduce Consistent Generative Query Networks (CGQN). Given the context, CGQN samples a stochastic latent scene representation that models the global stochasticity in the scene. Given an arbitrary query, the model uses the sampled scene representation to render a frame corresponding to the query. The model captures correlations over multiple target frames.
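The difference between independent and consistent sampling can be illustrated with a deliberately tiny toy, sketched below. It is not the CGQN model: the two "outcomes" (left or right turn) and the function names `sample_independent` and `sample_consistent` are hypothetical stand-ins chosen for this illustration. The only point is that drawing one global latent per scene makes every queried frame agree, whereas per-frame draws do not.

```python
import random

def sample_independent(num_frames, rng):
    # GQN-style toy: each frame draws its own outcome, so frames can disagree.
    return [rng.choice(["left", "right"]) for _ in range(num_frames)]

def sample_consistent(num_frames, rng):
    # CGQN-style toy: one global latent is drawn per scene and shared by all frames.
    z = rng.choice(["left", "right"])
    return [z for _ in range(num_frames)]

rng = random.Random(0)
num_scenes = 1000
independent = [sample_independent(3, rng) for _ in range(num_scenes)]
consistent = [sample_consistent(3, rng) for _ in range(num_scenes)]

# Fraction of sampled 3-frame sequences in which all frames agree on the turn.
agree_independent = sum(len(set(s)) == 1 for s in independent) / num_scenes
agree_consistent = sum(len(set(s)) == 1 for s in consistent) / num_scenes
print(agree_independent, agree_consistent)
```

In expectation, only about a quarter of the independently sampled sequences are coherent, while every sequence drawn through the shared latent is.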
To test CGQN we develop two synthetic datasets. Our first dataset consists of procedurally generated 2D shapes enacting a non-deterministic narrative of events. The goal of the prediction model is to reconstruct the narrative from a few random context frames. Quantitatively, the predictions generated by our model match those obtained by state-of-the-art video prediction algorithms, while being flexible with regards to the input and output structures as well as being able to generate several frames at once. To showcase the flexibility of our model we also apply it to consistent 3D scene reconstruction. To this end we introduce an additional dataset that consists of images of 3D scenes containing a cube with random MNIST digits engraved on each of its faces. We show quantitatively and qualitatively that CGQN outperforms GQN on this dataset, as GQN is unable to capture several correlated frames of occluded regions of the scene. We strongly encourage the reader to check the project website https://bit.ly/2jX7Vyu to view actual videos of experiments.
2 Related Work
Generative Query Networks (GQN) Our model builds on GQN (Eslami et al., 2018), a conditional DRAW (Gregor et al., 2015, 2016) network used for spatial prediction. The architectural changes we add to facilitate consistency include the rendering network, reconstructing multiple target frames, giving the posterior network an encoding of multiple targets instead of a single target frame, and using a global latent to capture stochastic elements of the scene. In addition, GQN was used for (primarily deterministic) spatial prediction. We cast stochastic video prediction as a similar querying task and highlight its advantages. Concurrent work on neural processes (Garnelo et al., 2018), which extends GQN, also has architectural similarities to CGQN. Neural processes are tested on 2D function regression, 2D Thompson sampling, and 2D contextual bandit problems. CGQN’s architecture and evaluation focus on high-dimensional datasets like video prediction and scene prediction.
Video Prediction A common limitation of video prediction models is their need for determinism in the underlying environments (Lotter et al., 2017; Srivastava et al., 2015; Boots et al., 2014; Finn et al., 2016; Liu et al., 2017). Creating models that can work with stochastic environments is the motivation behind numerous recent papers: On one end of the complexity spectrum there are models like Video Pixel Networks (Kalchbrenner et al., 2017) and its variants (Reed et al., 2017b) that are powerful but computationally expensive models. These models generate a video frame-by-frame, and generate each frame pixel-by-pixel. SV2P (Babaeizadeh et al., 2018) does not model each individual pixel, but autoregressively generates a video frame-by-frame. On the other end of the spectrum, sSSM (Buesing et al., 2018) is a faster model that generates an abstract state at each time step which can be decoded into a frame when required. All these stochastic models still generate frames (or states) one time-step at a time and the input and output frames must be consecutive. By bypassing these two issues CGQN extends this spectrum of models as a flexible and even faster stochastic model. Additionally, unlike prior methods, our models can be used for a wider range of tasks – we demonstrate our approach on non-deterministic 3D scene reconstruction.
Meta-Learning Finally, our task can be framed as a few-shot density estimation problem and is thus related to some ongoing research in the area of meta-learning. Meta-learning is often associated with few-shot classification tasks (Vinyals et al., 2016; Koch et al., 2015) but these algorithms can be extended to few-shot density estimation for image generation (Bartunov and Vetrov, 2018). Additional approaches include models with variational memory (Bornschein et al., 2017), attention (Reed et al., 2017a) and conditional latent variables (J. Rezende et al., 2016). Crucially, while these models can sample the estimated density at random, they cannot query it at specific target points, and their application is limited to visually less challenging datasets like Omniglot.
3.1 Problem Description
We consider problems where we have a collection of “scenes”. Scenes could be videos, spatial scenes, or in general any key-indexed collection. A scene $S$ consists of a collection of viewpoint-frame (key-value) pairs $(v_1, f_1), \ldots, (v_n, f_n)$, where $v_i$ refers to the indexing ‘viewpoint’ information and $f_i$ to the frame. For videos the ‘viewpoints’ are timestamps. For spatial scenes the ‘viewpoints’ are camera positions and headings. For notational simplicity, we assume that each scene has fixed length $n$ (but this is not a requirement of the model). The viewpoint-frame pairs in each scene are generated from a data generative process $D$, a joint distribution over $(v_1, f_1), \ldots, (v_n, f_n)$. Note that the viewpoint-frame pairs are typically not independent.
Each scene $S$ is split into a context and a target. The context $C$ contains $m$ viewpoint-frame pairs, $C = \{(v_i, f_i)\}_{i=1}^{m}$. The target contains the remaining $n - m$ viewpoints $V = \{v_i\}_{i=m+1}^{n}$ and corresponding target frames $F = \{f_i\}_{i=m+1}^{n}$. At evaluation time, the model receives the context $C$ and target viewpoints $V$ and should be able to sample possible frames $\hat{F}$ corresponding to the viewpoints $V$. In particular, the model parameterizes a (possibly implicit) conditional distribution $P_\theta(F \mid C, V)$, from which the frames are sampled.
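Given a training set of example scenes from data distribution $D$, the training objective is to find model parameters $\theta$ that maximize the log probability of the target frames $F$ given the context $C$ and the target viewpoints $V$, that is, $\mathbb{E}_{C, V, F \sim D}\left[\log P_\theta(F \mid C, V)\right]$.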
3.2 Model (Generation)
We implement CGQN as a latent variable model. For each scene $S$, CGQN encodes the $m$ viewpoint-frame pairs of the context $C$ by applying a representation function $M_\theta$ to each pair independently. The resulting representations $r_1, \ldots, r_m$ are aggregated in a permutation-invariant way to obtain a single representation $r$. The latent variable $z$ is then sampled from a prior $P_\theta(z \mid r)$ that is conditioned on this aggregated representation $r$. The idea behind $z$ is that it can capture global dependencies across the $n - m$ target viewpoints, which is crucial to ensure that the output frames are generated from a single consistent plausible scene. For each corresponding target viewpoint $v_i$, the model applies a deterministic rendering network $M_\gamma$ to $z$ and $v_i$ to get an output frame $\hat{f}_i$. Our model, CGQN, can thus be summarized as follows.
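The generative process just described can be sketched in symbols as below. The names are assumptions made for this sketch: $M_\theta$ for the representation function, $r_j$ for the per-pair representations, $P_\theta(z \mid r)$ for the conditional prior, $M_\gamma$ for the renderer, $m$ context pairs and $n$ pairs in total. Summation is used as the permutation-invariant aggregator, which is one common choice rather than something the description above fixes.

```latex
\begin{align}
r_j &= M_\theta(v_j, f_j), & j &= 1, \ldots, m \\
r &= \sum_{j=1}^{m} r_j \\
z &\sim P_\theta(z \mid r) \\
\hat{f}_i &= M_\gamma(z, v_i), & i &= m+1, \ldots, n
\end{align}
```

Note that $z$ is sampled once per scene and reused for every target viewpoint, which is what ties the rendered frames together.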
In CGQN, the representation network is implemented as a convolutional network. The latent $z$ is sampled using a convolutional DRAW (Gregor et al., 2015, 2016) prior, an LSTM-like model that recurrently samples latents over several iterations. The rendering network is implemented as an LSTM where the inputs $z$ and $v_i$ (the sampled latent and the query viewpoint) are fed in at each recurrent step. We give details of these implementations in the Appendix. Note that these building blocks are easily customizable. For example, DRAW can be replaced with a regular variational autoencoder, albeit at a cost to the visual fidelity of the generated samples.
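Since the building blocks are interchangeable, the data flow can be sketched with much simpler stand-ins. The NumPy toy below is an illustration only: it swaps the convolutional encoder for a single `tanh` layer, convolutional DRAW for a one-step diagonal Gaussian prior, and the LSTM renderer for a linear map. All sizes, weight names and function names are hypothetical, and the weights are random rather than trained.

```python
import numpy as np

rng = np.random.default_rng(0)
V_DIM, F_DIM, R_DIM, Z_DIM = 1, 16, 8, 4  # toy sizes (hypothetical)

# Randomly initialised toy parameters; a real model would learn these.
W_rep = rng.normal(size=(V_DIM + F_DIM, R_DIM))
W_mu = rng.normal(size=(R_DIM, Z_DIM))
W_logstd = rng.normal(size=(R_DIM, Z_DIM)) * 0.1
W_render = rng.normal(size=(Z_DIM + V_DIM, F_DIM))

def represent(v, f):
    """Encode one viewpoint-frame pair (stand-in for the convolutional network)."""
    return np.tanh(np.concatenate([v, f]) @ W_rep)

def aggregate(context):
    """Permutation-invariant aggregation; summation is an assumption here."""
    return sum(represent(v, f) for v, f in context)

def sample_latent(r, rng):
    """One-step diagonal Gaussian prior on z given r (stand-in for DRAW)."""
    mu, std = r @ W_mu, np.exp(r @ W_logstd)
    return mu + std * rng.normal(size=Z_DIM)

def render(z, v):
    """Deterministic renderer taking the latent and a query viewpoint."""
    return np.concatenate([z, v]) @ W_render

# Three context pairs (timestamp, flattened frame) and three query timestamps.
context = [(np.array([t]), rng.normal(size=F_DIM)) for t in (0.0, 0.1, 0.2)]
query_viewpoints = [np.array([t]) for t in (0.5, 0.7, 0.9)]

r = aggregate(context)
z = sample_latent(r, rng)                          # one latent per scene
frames = [render(z, v) for v in query_viewpoints]  # every query shares z
```

Because the context is aggregated by summation, reordering the context pairs leaves `r` unchanged, and all query frames can be rendered in one pass from the same `z` without generating intermediate frames.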
3.3 Model (Training)
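We wish to find parameters $\theta$ that maximize the log probability of the data under our model, $\mathbb{E}_{C, V, F \sim D}\left[\log P_\theta(F \mid C, V)\right]$, where $F$ denotes the target frames, $C$ the context and $V$ the target viewpoints.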
Since this optimization problem is intractable we introduce an approximate posterior $P_\phi(z \mid C, V, F)$ over the latent variable $z$, conditioned on the context $C$, the target viewpoints $V$ and the target frames $F$. We train the model by maximizing the evidence lower bound (ELBO), as in Rezende et al. (2014) and Kingma and Welling (2014); the derivation is given in the Appendix. The resulting formulation for the ELBO of our model is as follows.
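A sketch of this bound in symbols is given below. The notation is assumed for the sketch: $C$ is the context, $V$ and $F$ are the target viewpoints and frames, $z$ is the global latent, $P_\theta(z \mid C)$ is the conditional prior, $P_\phi(z \mid C, V, F)$ is the approximate posterior, and $P_\theta(F \mid z, V)$ is the likelihood of the target frames under the renderer.

```latex
\begin{equation}
\mathbb{E}_{C, V, F \sim D} \Big[
  \mathbb{E}_{z \sim P_\phi(z \mid C, V, F)} \big[ \log P_\theta(F \mid z, V) \big]
  - \mathrm{KL}\big( P_\phi(z \mid C, V, F) \,\|\, P_\theta(z \mid C) \big)
\Big]
\end{equation}
```

The first term rewards accurate reconstruction of the target frames from a posterior sample of $z$, and the second keeps the posterior close to what can be inferred from the context alone.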
Note that this expression is composed of two terms: the reconstruction probability and the KL-divergence from the approximate posterior to the conditional prior. We use the reparameterization trick to propagate gradients through the reconstruction probability. As we are considering Gaussian probability distributions, we compute the KL in closed form. For training, to ensure that our model’s likelihood has support everywhere, we add zero-mean, fixed variance Gaussian noise to the output frame of our model. This variance (referred to as the pixel-variance) is annealed down during the course of training.
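For reference, the two ingredients named above have short standard forms when the distributions are diagonal Gaussians. The NumPy sketch below assumes a single Gaussian latent per scene; with a multi-step prior such as DRAW, the same KL expression would be summed over steps. The function names are hypothetical.

```python
import numpy as np

def gaussian_kl(mu_q, std_q, mu_p, std_p):
    """Closed-form KL(q || p) between diagonal Gaussians, summed over dimensions."""
    var_q, var_p = std_q ** 2, std_p ** 2
    per_dim = np.log(std_p / std_q) + (var_q + (mu_q - mu_p) ** 2) / (2 * var_p) - 0.5
    return np.sum(per_dim)

def reparameterize(mu, std, rng):
    """Draw z = mu + std * eps with eps ~ N(0, I), so z is differentiable in mu and std."""
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps

rng = np.random.default_rng(0)
mu_post, std_post = np.array([1.0, 0.0]), np.array([1.0, 1.0])
mu_prior, std_prior = np.zeros(2), np.ones(2)

kl = gaussian_kl(mu_post, std_post, mu_prior, std_prior)
z = reparameterize(mu_post, std_post, rng)
```

Here `kl` corresponds to the divergence from the posterior to the prior, matching the direction used in the ELBO.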
We evaluate CGQN against a number of strong existing baselines on two tasks: a synthetic, combinatorial video prediction task and a 3D scene reconstruction task.
4.1 Video Prediction
In video prediction, given the first $k$ frames of a video, $f_1, \ldots, f_k$, the model’s goal is to sample plausible frames that follow, $f_{k+1}, \ldots, f_n$. We encourage the reader to view actual videos of our experiments at https://bit.ly/2jX7Vyu.
Narratives Dataset: We present quantitative and qualitative results on a set of synthetic datasets that we call “narrative” datasets. Each dataset is parameterized by a “narrative” which describes how a collection of shapes interact. For example, in the “Traveling Salesman” narrative (Figure 3), one shape sequentially moves to (and “visits”) 4 other shapes. A dataset consists of many videos which represent different instances of a single narrative. In each instance of a narrative, each shape has a randomly selected color (out of 12 colors), size (out of 4 sizes), and shape (out of 4 shapes), and is randomly positioned. While these datasets are not realistic, they are a useful testbed. In our Traveling Salesman narrative with 5 shapes, the number of distinct instances (ignoring position) is over 260 billion. With random positioning of objects, the real number of instances is higher.
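The instance count quoted above can be checked with a few lines of arithmetic, assuming each of the 5 shapes independently picks one of 12 colors, 4 sizes and 4 shapes, and ignoring position.

```python
colors, sizes, shapes = 12, 4, 4
num_objects = 5

per_object = colors * sizes * shapes   # 192 appearance combinations per object
total = per_object ** num_objects      # combinations across all 5 objects
print(total)  # 260919263232
```

This gives roughly 2.6 × 10^11 appearance combinations, consistent with the "over 260 billion" figure.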
Training CGQN: For CGQN, we associate frame with corresponding timestamp . For training, the model is given between and frames (and corresponding timestamps) randomly selected out of the first frames, and tasked to predict randomly selected frames. When evaluating metrics, the model is given the first frames, and tasked to predict all frames in the video.
[Figure panel labels: F1 (seen), F6 (seen), F11 (unseen), F16 (unseen), F21 (unseen); the same five columns are repeated for a second panel.]
GQN vs CGQN: We qualitatively show the difference between CGQN and GQN. As described in the introduction, GQN samples each frame independently. The frames sampled therefore do not form a coherent video, because each frame is from a different possible continuation. We illustrate this in Figure 3. In this Traveling Salesman narrative, an actor shape (in Figure 3, the green square) sequentially visits four other shapes. The context includes the first 6 frames, where the actor visits the first shape. The actor then visits the other three shapes in a random order. GQN, however, cannot capture a coherent video. In the GQN sample in Figure 3, the green square never visits the white pentagon. In the project website, https://bit.ly/2jX7Vyu, we also show examples of a narrative we call “DAG Chase”, which has 4 shapes all of which move during the narrative.
Flexibility of CGQN: CGQN is more flexible than existing video prediction models. When trained with arbitrary contexts, it can take arbitrary sets of frames as input, and directly predict arbitrary sets of output frames. We illustrate this in Figure 4. In this “Color Reaction” narrative, shape 1 moves to shape 2 over frames 1 - 6, shape 2 changes color and stays that color from frames 7 - 12. Figure 4A shows the ground truth narrative. CGQN can take two frames at the start of the video, and sample frames at the end of the video (as in Figures 4B, 4C). Alternatively, CGQN can go “backwards” and take two frames at the end of the video, and sample frames at the start of the video (as in Figures 4D, 4E).
Quantitative Comparisons: We quantitatively compare CGQN with sSSM (Buesing et al., 2018), SV2P (Babaeizadeh et al., 2018), and CDNA (Finn et al., 2016) on the Traveling Salesman narrative dataset, where one shape sequentially moves to (and “visits”) 4 other shapes. The training set contains 98K examples, and the test set contains 1K examples. To evaluate each model, we take 30 sample continuations for each video and compute the minimum mean squared error between the samples and the original video. This measures that the model can (in a reasonable number of samples) sample the true video. A model that exhibits mode collapse would fare poorly on this metric because it would often fail to produce a sample close to the original video.
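The metric described above can be sketched as follows. This is a minimal NumPy version under stated assumptions: the mean squared error is averaged over every frame and pixel of the video (the text does not specify the exact averaging), and the array layout `(samples, time, height, width, channels)` and the name `min_mse` are hypothetical.

```python
import numpy as np

def min_mse(samples, target):
    """Minimum over samples of the MSE between each sampled video and the target.

    samples: array of shape (S, T, H, W, C), S sampled continuations.
    target:  array of shape (T, H, W, C), the ground-truth video.
    """
    reduce_axes = tuple(range(1, samples.ndim))
    per_sample = np.mean((samples - target[None]) ** 2, axis=reduce_axes)
    return per_sample.min()

# Tiny deterministic usage example with two "samples" of a 2-frame video.
target = np.zeros((2, 2, 2, 1))
samples = np.stack([np.ones_like(target), 0.5 * np.ones_like(target)])
best = min_mse(samples, target)
```

In the evaluation described above, `S` would be 30. Taking the minimum rather than the mean means a model is not penalized for also producing other plausible continuations, only for failing to cover the true one.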
We swept over model size, learning rate, and stochasticity parameters for the models (see the Appendix for more details). We selected the best hyperparameter configuration, and ran the model with that hyperparameter configuration with 15 random seeds (10 for CDNA). We ran all models for 3 days using distributed ADAM on 4 Nvidia K80 GPUs. For sSSM, we discarded runs where the KL loss became too low or too high (these runs had very bad metric scores), and for SV2P we discarded a run which had especially poor metric scores. This discarding was done to the benefit of SV2P and sSSM – for CGQN and CDNA we used all runs. The runs for sSSM and SV2P had high variance, so we also compared the best 4/15 runs for sSSM, SV2P, CGQN and 3/10 runs for CDNA. The plots, with error bars of times the standard error of the mean, are shown in Figure 5.
We draw three main conclusions from the plots.
Figure 5(b) shows that our dataset is able to distinguish between the performance of video prediction models. CDNA, a deterministic model, does much worse than the other three models. SV2P, a dedicated stochastic video prediction model, performs the best.
Figure 5(a) shows that our model converges very reliably. Averaged across all runs, our model performs significantly better than sSSM, SV2P, and CDNA.
Figure 5(b) shows that our model is competitive with modern video prediction models. The most direct comparison for our model is sSSM. SV2P is a much slower model that generates each output frame sequentially. sSSM is a faster, state-space model that bypasses generating intermediate output frames (but still does not have the flexibility of our model).
4.2 Scene Reconstruction
Our model is also capable of consistent 3D scene reconstruction. In this setup, CGQN is provided with context frames from a single 3D scene together with the camera positions from which those frames were rendered. The model is then asked to sample plausible frames rendered from a set of arbitrary camera positions. Often the model is asked to sample frames in an occluded region of the scene, where there are multiple possibilities for what the region might contain. Even in the presence of this uncertainty, CGQN is able to sample consistent frames that form a coherent scene. We encourage the reader to view video visualizations of our experiments at https://bit.ly/2jX7Vyu.
MNIST Dice Dataset To demonstrate this, we develop a 3D dataset where each scene consists of a cube in a room. Each face of the cube has a random MNIST digit (out of 100 digits) engraved on it. In each scene, the cube is randomly positioned and oriented, the color and textures of the walls are randomly selected, and the lighting source (which creates shadows) is randomly positioned. The context frames show at most three sides of the dice, but the model may be asked to sample camera snapshots that involve the unseen fourth side. We show quantitatively and qualitatively that CGQN performs better than GQN on this dataset, because GQN is unable to capture a coherent scene.
GQN vs CGQN (Qualitative) GQN samples each frame independently, and does not sample a coherent scene. We illustrate this in Figure 7, where we show an example scene from our test-set. The context frames (blue cones) see three sides of the cube, but the model is queried (red cones) to sample the occluded fourth side of the cube. Figure 7 also shows the samples for CGQN and GQN. GQN (right column) independently samples a 7 and then a 0 on the unseen side of the dice. CGQN samples a coherent scene, where the sampled unseen digit is consistently rendered across different viewpoints. CGQN’s reconstruction accurately captures the ground truth digit, which shows that the model is capable of sampling the target. Note that all frames are captured from a circle with fixed radius, with the camera facing the center of the room. However, the frames are not equally spaced, which distinguishes this from video prediction tasks.
GQN vs CGQN (Quantitative) We can compare GQN and CGQN by analyzing the test-set negative ELBO (as a proxy for the test-set negative log likelihood) over multiple target frames, each showing different viewpoints of the same unseen face of the cube. This serves as a quantitative measure for the quality of the models’ scene reconstruction. To motivate why CGQN should do better, imagine that we have a perfectly trained GQN and a perfectly trained CGQN, each of which captures all the nuances of the scene. Since there are 100 possible digits engraved on the unseen side, there is a 1/100 chance that each sampled frame captures the ground truth digit on the unseen face. GQN samples the unseen digit independently for each viewpoint, so the probability that a set of three frames all capture the ground truth digit is 1/1,000,000. On the other hand, CGQN captures the correlations between frames. If the digit is correct in one of three frames, it should be correct in the other two frames. So the probability that a set of three frames all capture the ground truth digit is 1/100. In other words, a perfectly trained consistent model will have better log likelihoods than a perfectly trained factored model.
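The idealized argument above translates directly into a log-likelihood gap. The short calculation below assumes, as in the thought experiment, 100 equally likely digits and 3 target frames of the same unseen face; it says nothing about trained models.

```python
import math

num_digits = 100
num_frames = 3

p_factored = (1 / num_digits) ** num_frames  # each frame guesses independently
p_consistent = 1 / num_digits                # one shared guess for all frames

# Log-likelihood advantage of the consistent model, in nats.
gap_nats = math.log(p_consistent) - math.log(p_factored)
print(gap_nats)
```

The gap equals `(num_frames - 1) * log(num_digits)`, so it grows linearly with the number of target frames that share the same unseen content.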
In practice, the benefits of consistency may trade off with accuracy of rendering the scene. For example, the more consistent model could produce lower quality images. So it is important to compare the models’ performance by comparing the test-set ELBOs. Figure 6 compares the test-set ELBOs for CGQN and GQN. We ran 4 runs for each model, and picked the best 2/4 runs for each model. The results suggest that CGQN does achieve quantitatively better scene reconstruction than GQN. We repeated this experiment twice more, with different ‘pixel-variance’ values and obtained similar plots, as shown in the Appendix.
CGQN Consistency Analysis We also analyze the consistency of CGQN. We measure the KL divergence from the posterior to the prior network in a trained CGQN model. We give CGQN a context comprising frames that show three sides of the cube. We condition the posterior on one additional target frame that shows the unseen side of the cube, and compute the KL divergence from the posterior to the prior, $KL_1$. Alternatively, we condition the posterior on three additional target frames that show the unseen side of the dice, and compute the KL divergence from the posterior to the prior, $KL_3$. The 2 extra target frames added for $KL_3$ do not add any information, so $KL_3 \approx KL_1$ for a consistent model. On the other hand, for a factored model like GQN, $KL_3 = 3 KL_1$. We trained 12 CGQN models, and the mean $KL_1$ was 4.25, the mean $KL_3$ was 4.19, and the standard deviation of $KL_3 - KL_1$ was 0.092. This suggests that CGQN is consistent, in the intended sense.
We have presented an architecture for learning generative models in the visual domain that can be conditioned on arbitrary points in time or space. Our models can extrapolate forwards or backwards in time, without needing to generate intermediate frames. Moreover, given a small set of contextual frames they can be used to render 3D scenes from arbitrary camera positions. In both cases, they generate consistent sets of frames for a given context, even in the presence of stochasticity. One limitation of our method is that the stochastic latent representation is of a fixed size, which may limit its expressivity in more complicated applications – fixing this limitation and testing on more complex datasets are good avenues for future work. Among other applications, video prediction can be used to improve the performance of reinforcement learning agents on tasks that require lookahead (Racanière et al., 2017; Buesing et al., 2018). In this context, the ability to perform “jumpy” predictions that look many frames ahead in one go is an important step towards agents that can explore a search space of possible futures, as it effectively divides time into discrete periods. This is an avenue we will be exploring in future work.
We would like to thank Dumitru Erhan and Mohammad Babaeizadeh for the code for SV2P and helping us in getting SV2P working on our datasets, and Lars Buesing, Yee Whye Teh, Antonia Creswell, Chongli Qin, Jonathan Uesato, and Valentin Dalibard for their very useful feedback in preparing this manuscript.
- Babaeizadeh et al. (2018) Babaeizadeh, M., Finn, C., Erhan, D., Campbell, R. H., and Levine, S. (2018). Stochastic variational video prediction. In International Conference on Learning Representations.
- Bartunov and Vetrov (2018) Bartunov, S. and Vetrov, D. (2018). Few-shot generative modelling with generative matching networks. In International Conference on Artificial Intelligence and Statistics, pages 670–678.
- Boots et al. (2014) Boots, B., Byravan, A., and Fox, D. (2014). Learning predictive models of a depth camera & manipulator from raw execution traces. In 2014 IEEE International Conference on Robotics and Automation (ICRA), pages 4021–4028.
- Bornschein et al. (2017) Bornschein, J., Mnih, A., Zoran, D., and J. Rezende, D. (2017). Variational memory addressing in generative models. In Advances in Neural Information Processing Systems, pages 3923–3932.
- Buesing et al. (2018) Buesing, L., Weber, T., Racanière, S., Eslami, S. M. A., Rezende, D. J., Reichert, D. P., Viola, F., Besse, F., Gregor, K., Hassabis, D., and Wierstra, D. (2018). Learning and querying fast generative models for reinforcement learning. CoRR, abs/1802.03006.
- Eslami et al. (2018) Eslami, S. M. A., Jimenez Rezende, D., Besse, F., Viola, F., Morcos, A. S., Garnelo, M., Ruderman, A., Rusu, A. A., Danihelka, I., Gregor, K., Reichert, D. P., Buesing, L., Weber, T., Vinyals, O., Rosenbaum, D., Rabinowitz, N., King, H., Hillier, C., Botvinick, M., Wierstra, D., Kavukcuoglu, K., and Hassabis, D. (2018). Neural scene representation and rendering. Science, 360(6394):1204–1210.
- Finn et al. (2016) Finn, C., Goodfellow, I. J., and Levine, S. (2016). Unsupervised learning for physical interaction through video prediction. In Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain, pages 64–72.
- Garnelo et al. (2018) Garnelo, M., Schwarz, J., Rosenbaum, D., Viola, F., Rezende, D. J., Eslami, S. M. A., and Teh, Y. W. (2018). Neural processes. In Theoretical Foundations and Applications of Deep Generative Models Workshop, ICML.
- Gregor et al. (2016) Gregor, K., Besse, F., Jimenez Rezende, D., Danihelka, I., and Wierstra, D. (2016). Towards conceptual compression. In Lee, D. D., Sugiyama, M., Luxburg, U. V., Guyon, I., and Garnett, R., editors, Advances in Neural Information Processing Systems 29, pages 3549–3557. Curran Associates, Inc.
- Gregor et al. (2015) Gregor, K., Danihelka, I., Graves, A., Rezende, D., and Wierstra, D. (2015). Draw: A recurrent neural network for image generation. In Bach, F. and Blei, D., editors, Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 1462–1471, Lille, France. PMLR.
- J. Rezende et al. (2016) J. Rezende, D., Danihelka, I., Gregor, K., Wierstra, D., et al. (2016). One-shot generalization in deep generative models. In International Conference on Machine Learning, pages 1521–1529.
- Kalchbrenner et al. (2017) Kalchbrenner, N., van den Oord, A., Simonyan, K., Danihelka, I., Vinyals, O., Graves, A., and Kavukcuoglu, K. (2017). Video pixel networks. In Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, pages 1771–1779.
- Kingma and Welling (2014) Kingma, D. P. and Welling, M. (2014). Auto-encoding variational bayes. In International Conference on Learning Representations.
- Koch et al. (2015) Koch, G., Zemel, R., and Salakhutdinov, R. (2015). Siamese neural networks for one-shot image recognition. In ICML Deep Learning Workshop, volume 2.
- Liu et al. (2017) Liu, Z., Yeh, R. A., Tang, X., Liu, Y., and Agarwala, A. (2017). Video frame synthesis using deep voxel flow. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 4473–4481.
- Lotter et al. (2017) Lotter, W., Kreiman, G., and Cox, D. D. (2017). Deep predictive coding networks for video prediction and unsupervised learning. In International Conference on Learning Representations.
- Racanière et al. (2017) Racanière, S., Weber, T., Reichert, D., Buesing, L., Guez, A., Jimenez Rezende, D., Puigdomènech Badia, A., Vinyals, O., Heess, N., Li, Y., Pascanu, R., Battaglia, P., Hassabis, D., Silver, D., and Wierstra, D. (2017). Imagination-augmented agents for deep reinforcement learning. In Advances in Neural Information Processing Systems 30, pages 5690–5701.
- Reed et al. (2017a) Reed, S. E., Chen, Y., Paine, T., van den Oord, A., Eslami, S. M. A., Rezende, D. J., Vinyals, O., and de Freitas, N. (2017a). Few-shot autoregressive density estimation: Towards learning to learn distributions. CoRR, abs/1710.10304.
- Reed et al. (2017b) Reed, S. E., van den Oord, A., Kalchbrenner, N., Colmenarejo, S. G., Wang, Z., Chen, Y., Belov, D., and de Freitas, N. (2017b). Parallel multiscale autoregressive density estimation. In Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, pages 2912–2921.
- Rezende et al. (2014) Rezende, D. J., Mohamed, S., and Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the 31st International Conference on Machine Learning, Proceedings of Machine Learning Research, pages 1278–1286. PMLR.
- Srivastava et al. (2015) Srivastava, N., Mansimov, E., and Salakhutdinov, R. (2015). Unsupervised learning of video representations using LSTMs. In Proceedings of the 32nd International Conference on Machine Learning, ICML’15, pages 843–852. JMLR.org.
- Vinyals et al. (2016) Vinyals, O., Blundell, C., Lillicrap, T., Wierstra, D., et al. (2016). Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pages 3630–3638.
Appendix A: ELBO Derivation
The model is trained by maximizing an evidence lower bound, as in variational auto-encoders. We begin by expressing the log probability of the data in terms of the model’s latent variable $z$.
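In symbols, this marginalization can be sketched as below, for a single scene. The notation is assumed for the sketch: $C$ is the context, $V$ and $F$ are the target viewpoints and frames, $P_\theta(z \mid C)$ is the conditional prior over the latent $z$, and $P_\theta(F \mid z, V)$ is the likelihood of the target frames under the renderer.

```latex
\begin{equation}
\log P_\theta(F \mid C, V)
  = \log \mathbb{E}_{z \sim P_\theta(z \mid C)} \big[ P_\theta(F \mid z, V) \big]
\end{equation}
```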
The derivative of this objective is intractable because of the log outside the expectation. Using Jensen’s inequality and swapping the log and the expectation leads to an unbiased estimator that collapses to the mean of the distribution. Instead, we use the standard trick of parameterizing an approximate posterior distribution $P_\phi(z \mid C, V, F)$, which is conditioned on the target frames as well as the context. Instead of sampling $z$ from the prior we sample $z$ from the approximate posterior to get an equivalent objective.
Note that for efficiency, we often sample a subset of the target viewpoint-frame pairs instead of conditioning the posterior and training on the entire set. We now apply Jensen’s inequality to get a lower bound (ELBO) that we maximize as a surrogate.
We can split the ELBO into two terms: the reconstruction probability and the KL-divergence between the prior and the posterior.
Since we consider Gaussian probability distributions, the KL can be computed in closed form. For the reconstruction probability, we note that each of the target frames is generated independently, conditional on the latent variable and the corresponding viewpoint.
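For reference, the steps described in this appendix can be sketched in the following form. The notation is an assumption made for this sketch and may differ from the paper's own symbols: C denotes the context, (v_i, f_i) for i = 1, ..., n the sampled target viewpoint-frame pairs, V and F the collected target viewpoints and frames, \pi_\theta the conditional prior, q_\phi the approximate posterior, and p_\theta the rendering likelihood.

```latex
\begin{align}
\log p_\theta(F \mid V, C)
  &= \log \mathbb{E}_{z \sim \pi_\theta(z \mid C)}
     \big[ p_\theta(F \mid z, V) \big] \\
  &= \log \mathbb{E}_{z \sim q_\phi(z \mid C, V, F)}
     \left[ \frac{\pi_\theta(z \mid C)}{q_\phi(z \mid C, V, F)}\,
            p_\theta(F \mid z, V) \right] \\
  &\geq \mathbb{E}_{z \sim q_\phi(z \mid C, V, F)}
     \big[ \log p_\theta(F \mid z, V) \big]
     - \mathrm{KL}\big( q_\phi(z \mid C, V, F) \,\|\, \pi_\theta(z \mid C) \big), \\
\log p_\theta(F \mid z, V)
  &= \sum_{i=1}^{n} \log p_\theta(f_i \mid z, v_i).
\end{align}
```

The second line is the importance-sampling rewrite (sampling from the posterior instead of the prior), the third applies Jensen's inequality, and the last line is the factorization over independently generated target frames.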
We can then apply the standard reparameterization trick (where we sample from a unit Gaussian and scale the samples accordingly). This gives us a differentiable objective where we can compute derivatives via backpropagation and update the parameters with stochastic gradient descent.
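As an illustration of the two ingredients named above, the following is a minimal sketch (not the authors' implementation) of the closed-form KL between diagonal Gaussians and of a reparameterized sample. The function names and the use of plain Python lists are assumptions made for this example.

```python
import math
import random


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(q || p) between diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
    return total


def reparameterized_sample(mu, sigma, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1).

    Once eps is fixed, z is a deterministic (and differentiable) function
    of mu and sigma, which is what makes backpropagation possible.
    """
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)
z = reparameterized_sample([0.0, 1.0], [1.0, 0.5], rng)
kl_same = gaussian_kl([0.0], [1.0], [0.0], [1.0])     # identical distributions
kl_shifted = gaussian_kl([1.0], [1.0], [0.0], [1.0])  # unit shift in the mean
```

In an automatic-differentiation framework the same two functions would be written with tensors, so that gradients flow through `mu` and `sigma`.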
Appendix B: Model Details and Hyperparameters
We first explain some of the hyper-parameters in our model. For reproducibility, we then give the hyper-parameter values that we used for the narrative concepts task and the 3D scene reconstruction task.
For CGQN, recall that we added Gaussian noise to the output of the renderer to ensure that the likelihood has support everywhere. We call the variance of this Gaussian distribution the “pixel-variance”. When the pixel-variance is very high, the ELBO loss depends much more on the KL-divergence between the prior and the posterior than on the mean squared error between the target and predicted images. That is, a small change in the KL term causes a much larger change in the ELBO than a small change in the mean squared error. As such, the training objective forces the posterior to match the prior in order to keep the KL low. This makes the model predictions deterministic. On the other hand, when the pixel-variance is near zero, the ELBO loss depends much more on the mean squared error between the target and predicted images. In this case, the model allows the posterior to deviate far from the prior in order to minimize the mean squared error. This leads to good reconstructions but poor samples, since the prior does not overlap well with the (possible) posteriors.
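The trade-off can be made concrete with a small sketch. Assuming an isotropic Gaussian likelihood over pixels (the function names here are hypothetical, chosen for this example), the summed squared error enters the negative ELBO with coefficient 1 / (2 * pixel_variance), while the KL term has coefficient 1.

```python
import math


def gaussian_nll(target, prediction, pixel_variance):
    """Negative log-likelihood of target pixels under N(prediction, pixel_variance)."""
    nll = 0.0
    for t, p in zip(target, prediction):
        nll += 0.5 * math.log(2.0 * math.pi * pixel_variance)
        nll += (t - p) ** 2 / (2.0 * pixel_variance)
    return nll


def negative_elbo(target, prediction, pixel_variance, kl):
    """Reconstruction term plus KL term; the KL always has unit weight."""
    return gaussian_nll(target, prediction, pixel_variance) + kl


def squared_error_weight(pixel_variance):
    """Coefficient multiplying the summed squared error in the negative ELBO."""
    return 1.0 / (2.0 * pixel_variance)


weight_high_variance = squared_error_weight(2.0)  # error term is down-weighted
weight_low_variance = squared_error_weight(0.5)   # error term counts more
```

A larger pixel-variance shrinks the squared-error coefficient relative to the KL, which is the regime where the posterior is pushed toward the prior.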
As such, we need to find a good pixel-variance that is neither too high nor too low. In our case, we linearly anneal the pixel-variance from an initial value to a final value over 100,000 training steps. Note that the other models, sSSM and SV2P, have an equivalent hyper-parameter, in which the KL divergence is multiplied by a scalar value. SV2P also uses an annealing-like strategy (Babaeizadeh et al., 2018).
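A linear annealing schedule of this kind can be sketched as follows. The defaults use the initial and final pixel-variance values (2.0 and 0.5) listed for the traveling salesman dataset; the function name and the clamping behaviour after the last annealing step are assumptions of this sketch.

```python
def linear_anneal(step, start=2.0, end=0.5, num_steps=100_000):
    """Linearly interpolate from `start` to `end` over `num_steps` steps.

    Steps beyond `num_steps` are clamped, so the value stays at `end`.
    """
    fraction = min(max(step / num_steps, 0.0), 1.0)
    return start + fraction * (end - start)


variance_at_start = linear_anneal(0)
variance_midway = linear_anneal(50_000)
variance_after = linear_anneal(250_000)
```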
For the traveling salesman dataset, we used the following parameters for the DRAW conditional prior/posterior net (Gregor et al., 2015, 2016). The rendering network was identical, except we do not have a conditional posterior, making it deterministic.
| Parameter | Value | Description |
|---|---|---|
| nt | 4 | The number of DRAW steps in the network. |
| nf_to_hidden | 64 | The number of channels in the LSTM layer. |
|  | 2.0 | The initial pixel-variance. |
|  | 0.5 | The final pixel-variance. |
For the encoder network, we apply a convolutional net to each image separately. The convolutional net has 4 layers, with a ReLU non-linearity between each layer (but not after the last layer). The first layer has 8 channels, a kernel shape of 2x2, and stride lengths of 2x2. The second layer has 16 channels, a kernel shape of 2x2, and stride lengths of 2x2. The third layer has 32 channels, a kernel shape of 3x3, and stride lengths of 1x1. The final layer has 32 channels, a kernel shape of 3x3, and stride lengths of 1x1.
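To see how these four layers change the spatial resolution, the sketch below traces the output shape through the stack. Two things are assumptions not stated in the text: the convolutions use no padding ("valid"), and the 64x64 input size is purely illustrative.

```python
# (output_channels, kernel_size, stride) for each encoder layer, as described above.
ENCODER_LAYERS = [
    (8, 2, 2),
    (16, 2, 2),
    (32, 3, 1),
    (32, 3, 1),
]


def conv_output_size(size, kernel, stride):
    """Spatial output size of a convolution with no padding (assumed)."""
    return (size - kernel) // stride + 1


def encoder_output_shape(height, width):
    """Return (height, width, channels) after the four encoder layers."""
    channels = None
    for channels, kernel, stride in ENCODER_LAYERS:
        height = conv_output_size(height, kernel, stride)
        width = conv_output_size(width, kernel, stride)
    return height, width, channels


# Hypothetical 64x64 input image.
shape = encoder_output_shape(64, 64)
```

With "same" padding instead, only the two strided layers would reduce the resolution, so the padding choice matters for the size of the resulting representation.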
For all other datasets, we use the same encoder network and similar hyper-parameters. For the MNIST Cube 3D scene reconstruction task, the main differences are that we use nt: 6, nf_to_hidden: 128, nf_dec: 128. We also had to use a slightly different annealing strategy for the pixel-variance. Simply annealing the variance down led to the KL-values collapsing to 0 and never rising back up. In other words, the predictions became deterministic. We use an annealing strategy somewhat similar to that of Babaeizadeh et al. (2018). We keep the pixel-variance at 2.0 for the first 100,000 iterations, then keep it at 0.2 for 50,000 iterations, then keep it at 0.4 for 50,000 iterations, and then finally leave it at 0.9 until the end of training. The intuition is to keep the pixel-variance high first (and hence the KL low) so that the model can make good deterministic predictions. Then, we reduce the pixel-variance to a low value (0.2) so that the model can capture the stochasticity in the dataset. Finally, we increase the pixel-variance so that the prior and the posteriors are reasonably similar.
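The piecewise-constant schedule described above can be written as a small lookup. This is a sketch: the function name is hypothetical, and it assumes iterations are counted from 0, so "the first 100,000 iterations" means steps 0 through 99,999.

```python
def mnist_cube_pixel_variance(step):
    """Piecewise-constant pixel-variance schedule for the MNIST Cube task."""
    if step < 100_000:
        return 2.0  # high variance: KL kept low, near-deterministic predictions
    if step < 150_000:
        return 0.2  # low variance: lets the model capture stochasticity
    if step < 200_000:
        return 0.4  # intermediate value
    return 0.9      # held until the end of training


schedule_samples = [mnist_cube_pixel_variance(s) for s in (0, 100_000, 150_000, 200_000)]
```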
Note that for each stochastic video prediction model we tested (CDNA, SV2P, sSSM), we swept over hyper-parameters, doing a grid search. We swept over size parameters, the learning rate, and the parameter used to control the KL divergence between the conditional posterior and the conditional prior. We ensured that we had tested hyper-parameter values slightly above, and below, the ones that we found worked best.
Appendix C: Additional Consistent ELBO Experiments
We compare the negative ELBO over 3 target frames for CGQN and GQN for different pixel-variance values. The pixel-variance is effectively a hyper-parameter that we can tune based on visual inspection of the samples, or an alternative metric. Good pixel-variance values for GQN and CGQN are different. As such, we included the pixel-variance values we found to work well for GQN, and the values we found to work well for CGQN. We tried pixel-variance values of , where is the number of targets and we tried . The plot for is shown in the main paper. We show the best half (two out of four) runs for GQN and CGQN for , , and below. Note that the ELBO for CGQN is better than the ELBO for GQN in all the plots below. If CGQN and GQN are equally good models (null hypothesis), the probability that CGQN did better than GQN in all these 6 curves is upper bounded by . So we can reject the null hypothesis with high statistical significance.
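The significance argument is a sign-test style calculation. As a hedged illustration, suppose the 6 comparisons are independent and that, under the null hypothesis, CGQN beats GQN in each one with probability at most 1/2; both are assumptions of this sketch, and the function name is hypothetical. The probability of winning all 6 is then at most (1/2)^6.

```python
from fractions import Fraction


def all_wins_probability(num_comparisons, win_probability=Fraction(1, 2)):
    """Probability of winning every one of `num_comparisons` independent trials."""
    return win_probability ** num_comparisons


bound = all_wins_probability(6)
bound_as_float = float(bound)
```

Under these assumptions the bound is 1/64, which is below the conventional 0.05 threshold.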