A Distributed Neural Network Architecture for Robust Non-Linear Spatio-Temporal Prediction
We introduce a distributed spatio-temporal artificial neural network architecture (DISTANA). It encodes mesh nodes using recurrent neural prediction kernels (PKs), while neural transition kernels (TKs) transfer information between neighboring PKs. Together, PKs and TKs model and predict spatio-temporal time series dynamics. As a consequence, DISTANA assumes that the observed data are generated by generally applicable causes, which may be locally modified. DISTANA learns in a parallel, spatially distributed manner, scales to large problem spaces, is capable of approximating complex dynamics, and is particularly robust to overfitting when compared to other competitive ANN models. Moreover, it is applicable to heterogeneously structured meshes.
Modeling and predicting non-linear spatio-temporal process dynamics is challenging for current pattern recognition systems [1]. Representative problem domains range from the analysis of dynamic brain activity in neuroscience [6], over video streams [5], information flow in social networks [2], and traffic prediction [11], to climate and weather forecasting [8, 10]. In all cases, the major challenge is to infer, model, and predict the underlying causes that generate the perceived data stream, propagating the involved causal dynamics through graphs and distributed sensor meshes. A key property that all spatio-temporal processes have in common is that some general underlying principles, such as physics when observing natural processes, apply irrespective of time or location. As a result, the same predictable patterns, individually modified by local spatial and temporal influences, are observable repeatedly at different spatial locations and points in time.
We introduce DISTANA, a distributed spatio-temporal artificial neural network architecture, which actively searches for such characteristics in time series data. DISTANA learns predictive, spatio-temporal neural network kernels (PKs), which are applied to all nodes of a mesh. Additional information-routing transition kernels (TKs) laterally connect the PKs. PKs and TKs each share their weights, thus applying the same operations at different locations. This enables efficient parallel computation in, and learning from, all spatial locations in the mesh. Moreover, DISTANA is predisposed to identify the universal, recurring causes of the observed pattern dynamics. Compared to seven other ANN models, including convolutional neural networks (CNNs), recurrent neural networks (RNNs), and combinations of both (e.g. ConvLSTM), DISTANA approximates circularly propagating waves with both higher accuracy and higher robustness, is considerably less prone to overfitting, and has the potential to handle heterogeneously distributed sensor meshes. In the near future, we therefore intend to apply DISTANA to related but more challenging real-world problems, such as learning to predict the partially chaotic processes that generate our weather and climate.
2 Related Work
While CNNs [7] have been shown to process spatially distributed information such as images efficiently and accurately, RNNs, and long short-term memory cells (LSTMs) [3] in particular, were designed to handle temporally distributed data such as time series. Recently, Shi et al. [10] proposed ConvLSTM, which combines CNNs and LSTMs into a convolution-gating architecture that processes spatial and temporal information simultaneously. GridLSTM [4], on the other hand, extends LSTMs to process not only the temporal but also the spatial data dimensions sequentially. DISTANA belongs to a third related class of architectures, referred to as graph neural networks (GNNs) [9]. GNNs treat graph vertices and edges in two different neural network components. Unlike earlier GNNs, however, DISTANA integrates LSTM structures, projects the graph, i.e. its mesh, onto a metric space, and assumes universal causes underlying the observable spatio-temporal data.
3 Model Description
DISTANA is a two-network architecture that consists of a PK network, which generates dynamic predictions at each desired spatial position, and a TK network, which models transitions between two or more adjacent PKs. The PKs and the TKs, each sharing their respective weights, are applied across a sensor mesh. As depicted in Figure 1, the PK and TK networks can be applied simultaneously at all positions in space, processing spatially distributed data. Each PK instance receives (1) dynamic input, which is subject to prediction and changes over time, (2) static information, which stays constant and characterizes the location of each PK, and (3) lateral input from neighboring PKs. The TK network, which makes our approach unique, is introduced to model location-sensitive transitions between PKs and thus to enable local, context-dependent spatial information propagation. In principle, both PK and TK networks can have arbitrary topologies and may incorporate recurrent connections (cf. Figure 1).
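The following Python/NumPy sketch makes the weight sharing and lateral routing concrete. It represents a mesh as an adjacency list and applies one shared linear TK to every edge, then sums the resulting messages into each PK's lateral input. The helper names `grid_neighbors` and `route_lateral` are hypothetical, and so are two modeling choices: incoming messages are summed, and the TK sees only the sender's lateral output. This is not the paper's implementation. Because the mesh is just a dictionary of neighbor lists, the same routine applies unchanged to irregular meshes.

```python
import numpy as np


def grid_neighbors(height, width):
    """Hypothetical helper: 8-neighborhood adjacency list of a regular height x width grid."""
    neighbors = {}
    for r in range(height):
        for c in range(width):
            neighbors[r * width + c] = [
                (r + dr) * width + (c + dc)
                for dr in (-1, 0, 1)
                for dc in (-1, 0, 1)
                if (dr, dc) != (0, 0) and 0 <= r + dr < height and 0 <= c + dc < width
            ]
    return neighbors


def route_lateral(lateral_out, neighbors, w_tk, b_tk):
    """Lateral input of every PK: one shared linear TK applied to each incoming edge, summed.

    lateral_out: (num_nodes, d_out); w_tk: (d_out, d_in); b_tk: (d_in,).
    neighbors: dict mapping node -> list of sender nodes (regular or irregular mesh).
    """
    lateral_in = np.zeros((lateral_out.shape[0], w_tk.shape[1]))
    for receiver, senders in neighbors.items():
        for sender in senders:
            # identical weights on every edge = weight sharing across the mesh
            lateral_in[receiver] += lateral_out[sender] @ w_tk + b_tk
    return lateral_in
```

Consider a 3x3 grid with a unit lateral output everywhere and an identity TK. The center node then receives 8.0, edge nodes receive 5.0, and corner nodes receive 3.0, one unit per neighbor.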
In two experiments, which differ in the data sets used, DISTANA is compared with several ANN architectures, including fully connected networks, CNNs, and RNNs, at modeling a wave-like spatio-temporal process that is distributed in a mesh (cf. Figure 2). Training and test errors are calculated as the mean squared error between network output and target, where the target is the network input shifted by one time step. The networks thus have to predict the next time step of a 2D circular wave sequence. The test error is calculated over 65 time steps of closed-loop performance, in which the network is fed its own dynamic predictions from the previous time step. The closed loop begins after 15 steps of teacher forcing, which ground the recurrent activity in the network.
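The evaluation protocol can be sketched as follows, assuming a model exposed through a hypothetical interface `step_fn(x, state) -> (prediction, new_state)`. The first `teacher_steps` inputs come from the data. Afterwards, the model receives its own previous prediction, and only the closed-loop predictions enter the mean squared error. A sequence of 81 frames yields 15 teacher-forced steps followed by 65 closed-loop predictions; this frame count is an assumption chosen to match the step counts above.

```python
import numpy as np


def closed_loop_mse(step_fn, sequence, teacher_steps=15, state=None):
    """Mean squared error over the closed-loop part of one sequence.

    step_fn(x, state) -> (prediction_for_next_step, new_state) is a hypothetical model interface.
    Inputs 0 .. teacher_steps-1 are taken from the data (teacher forcing); afterwards the
    model is fed its own previous prediction (closed loop).
    """
    if teacher_steps < 1:
        raise ValueError("at least one teacher-forced step is needed to start the loop")
    errors = []
    prediction = None
    for t in range(len(sequence) - 1):
        x = sequence[t] if t < teacher_steps else prediction
        prediction, state = step_fn(x, state)
        if t >= teacher_steps:  # only closed-loop predictions are scored
            errors.append(np.mean((prediction - sequence[t + 1]) ** 2))
    return float(np.mean(errors))
```

A model that simply repeats its input (`lambda x, s: (x, s)`) keeps predicting the last teacher-forced frame, so its closed-loop error grows with the distance from that frame. A model that predicts the next frame exactly reaches zero error.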
4.1 Data Set 1
First, a basic data set was created in which single waves propagate outwards. The waves are not reflected at the borders, which yields comparably simple dynamics. Waves were generated using
where is the wave height of the field at a certain position and time, defines the oscillating wave height given the distance to the wave center and the current time step , and is a decay term that makes waves decay with their distance to the wave center over time, controlled by a decay factor . For large values of the wave fades faster than for small values of . Constant is the wave velocity, and field values that have not yet been reached by the wave in time step are explicitly set to zero.
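The exact generating equation for data set 1 is not reproduced here, so the following sketch only mimics the described ingredients. These are an oscillation that depends on the distance r to the wave center and on the time step t, an exponential decay with distance controlled by a decay factor, a constant velocity, and zeros beyond the wave front. The functional form sin(k(vt - r)) * exp(-gamma * r), the function name `circular_wave`, and all parameter defaults are assumptions for illustration.

```python
import numpy as np


def circular_wave(size, steps, center, velocity=0.5, decay=0.1, wavenumber=1.0):
    """Hypothetical generator for one outward-propagating circular wave.

    Returns an array of shape (steps, size, size). The form
    sin(wavenumber * (velocity * t - r)) * exp(-decay * r) is an assumption.
    """
    ys, xs = np.mgrid[0:size, 0:size]
    r = np.hypot(xs - center[0], ys - center[1])  # distance to the wave center (cx, cy)
    frames = np.zeros((steps, size, size))
    for t in range(steps):
        oscillation = np.sin(wavenumber * (velocity * t - r))  # depends on r and t
        damping = np.exp(-decay * r)  # larger decay factor -> faster fading
        frame = oscillation * damping
        frame[r > velocity * t] = 0.0  # not yet reached by the wave front
        frames[t] = frame
    return frames
```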
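Table 1: Model (number of parameters), training error, test error, inference time, test error after training on a single example (1-train-ex.), and test error on waves with variable velocity (var. wave).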
Table 1 shows the performance of all compared models at approximating the circular wave. In addition to the training and test errors, we report the number of parameters and the inference time for one sequence (consisting of 80 time steps) for each model. To rigorously test the generalization abilities of all models, we also trained them on one single sequence and computed the test error on unseen sequences (test error 1-train-ex.). Furthermore, to evaluate the models' abilities to approximate variable dynamics, we trained them on waves that travel with varying velocities (test error var. wave). See Figure 3 for a visualization of the performance. Spatial scalability, which indicates whether a model can be applied to an input field of a different resolution, is reported in the model descriptions below.
Two baselines serve as upper bounds on the error. The first baseline assumes an identity function that returns the input directly, whereas Baseline zero assumes a model that always predicts zeros.
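Both baselines can be scored with a few lines of NumPy. The sketch below computes them on the next-step targets of a single sequence. Restricting the computation to the closed-loop window, or averaging over a test set, would follow the same pattern. The function name `baseline_errors` is hypothetical.

```python
import numpy as np


def baseline_errors(sequence):
    """Next-step MSE of the identity and zero baselines on one (T, ...) sequence."""
    inputs, targets = sequence[:-1], sequence[1:]
    identity_mse = float(np.mean((inputs - targets) ** 2))  # predict x_{t+1} = x_t
    zero_mse = float(np.mean(targets ** 2))                  # predict x_{t+1} = 0
    return identity_mse, zero_mse
```

A trained model is only useful if its error stays below both values. The identity baseline is hard to beat when frames change slowly, and the zero baseline is hard to beat when wave amplitudes are small.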
Fully Connected Networks
A naive, spatially non-scalable approach to modeling the circular wave is a fully connected linear network (FC-Linear), with cells, receiving the flattened input. A more elaborate model is FC-LSTM, which replaces the linear layer of FC-Linear with a 256-cell LSTM layer to facilitate temporal information processing.
To reduce the number of parameters and obtain a spatially scalable model, numerous CNNs with different kernel sizes, varying numbers of feature maps, and two convolutional layers were evaluated. The best results, which are reported here, were achieved using a kernel size of and one feature map.
Temporal Convolution Network (TCN)
TCNs, as a spatially scalable approach, were applied with three 3D convolution layers, each with a kernel and feature maps. Other depths or kernel sizes did not seem to improve performance.
The CNN approach was extended by inserting a fully connected LSTM layer, which makes the model spatially non-scalable, after a variable number of layers (one to three convolution layers followed by the LSTM layer and one to three transposed convolution layers). The best results were achieved with one convolution layer, followed by a flat LSTM layer and a transposed convolution layer with a skip connection.
Two models of the spatially scalable ConvLSTM architecture, both with two layers and kernel size three, are reported: ConvLSTM1 with one feature map in both layers, and ConvLSTM8 with eight feature maps in the first layer, which are reduced to one in the second layer.
GridLSTM and BiGridLSTM
Spatially scalable GridLSTM models are evaluated. GridLSTM runs forward over all three data dimensions. BiGridLSTM processes the data forward in time but bidirectionally over space.
Our model (DISTANA)
The PK consists of a linear layer, followed by a layer of either four or 26 LSTM cells, and another linear layer. The TK is a single linear layer and, like the other linear layers, is used without an activation function. Like some of the other models above, DISTANA is spatially scalable.
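The following NumPy sketch shows one DISTANA-style time step on a regular grid, under stated assumptions. The PK maps the concatenated dynamic and lateral input through a linear layer, an LSTM layer, and a second linear layer. That last layer outputs both the next-step prediction and the lateral output. The TK sums the lateral outputs of the four direct neighbors (zero-padded at the borders) and applies a shared linear map. The following are our own assumptions: the class name `DistanaSketch`, the LSTM gate ordering, the 4-neighborhood, the random initialization, and the omission of static inputs. Because one parameter set serves every node, the same object can be applied to grids of any size, which is what makes the architecture spatially scalable.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class DistanaSketch:
    """Hypothetical DISTANA-style model on an H x W grid: one PK (linear -> LSTM -> linear)
    and one linear TK, each with a single weight set shared by all grid nodes."""

    def __init__(self, lstm_cells=4, pre_size=1, lateral_size=1):
        d_in = 1 + lateral_size  # dynamic input + lateral input (no static input here)
        self.w_pre = rng.normal(0.0, 0.1, (d_in, pre_size))
        self.w_lstm = rng.normal(0.0, 0.1, (pre_size + lstm_cells, 4 * lstm_cells))
        self.b_lstm = np.zeros(4 * lstm_cells)
        self.w_post = rng.normal(0.0, 0.1, (lstm_cells, 1 + lateral_size))
        self.w_tk = rng.normal(0.0, 0.1, (lateral_size, lateral_size))
        self.lstm_cells = lstm_cells
        self.lateral_size = lateral_size

    def initial_state(self, height, width):
        n = height * width
        return (np.zeros((n, self.lstm_cells)),               # LSTM hidden state per node
                np.zeros((n, self.lstm_cells)),               # LSTM cell state per node
                np.zeros((height, width, self.lateral_size)))  # lateral outputs

    def tk(self, lateral_out):
        # Sum the lateral outputs of the four direct neighbors (zero outside the grid),
        # then apply the shared linear TK without an activation function.
        p = np.pad(lateral_out, ((1, 1), (1, 1), (0, 0)))
        neighbor_sum = p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]
        return neighbor_sum @ self.w_tk

    def step(self, field, state):
        """field: (H, W) dynamic input; returns the (H, W) next-step prediction and new state."""
        h_prev, c_prev, lateral_prev = state
        height, width = field.shape
        lateral_in = self.tk(lateral_prev).reshape(height * width, -1)
        x = np.concatenate([field.reshape(-1, 1), lateral_in], axis=1)
        pre = x @ self.w_pre  # linear preprocessing layer
        gates = np.concatenate([pre, h_prev], axis=1) @ self.w_lstm + self.b_lstm
        i, f, o, g = np.split(gates, 4, axis=1)
        c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        out = h @ self.w_post  # linear output layer
        prediction = out[:, 0].reshape(height, width)
        lateral_out = out[:, 1:].reshape(height, width, self.lateral_size)
        return prediction, (h, c, lateral_out)
```

Calling `step` repeatedly unrolls the model over a sequence, where each input is either data (teacher forcing) or the previous prediction (closed loop). Training would additionally require an automatic differentiation framework, which this sketch omits.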
4.2 Data Set 2
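To increase data complexity, a second data set was created in which waves are reflected at the borders, such that wave fronts interact. We focus our analysis on the most promising architectures determined above. For wave data generation, the two-dimensional wave equation
$$\frac{\partial^2 u}{\partial t^2} = c^2 \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right),$$
with wave velocity $c$, was solved numerically using the second-order central difference approximation
$$\frac{\partial^2 u}{\partial \xi^2} \approx \frac{u(\xi + \Delta\xi) - 2u(\xi) + u(\xi - \Delta\xi)}{\Delta\xi^2},$$
where $\xi \in \{x, y, t\}$ stands for a variable of the function $u(x, y, t)$ and $\Delta\xi$ is the approximation step size with respect to the considered variable $\xi$. Inserting this approximation for all three variables into the wave equation, with subscripts indexing grid positions and superscripts indexing time steps, yields
$$\frac{u^{t+1}_{x,y} - 2u^{t}_{x,y} + u^{t-1}_{x,y}}{\Delta t^2} = c^2 \left( \frac{u^{t}_{x+1,y} - 2u^{t}_{x,y} + u^{t}_{x-1,y}}{\Delta x^2} + \frac{u^{t}_{x,y+1} - 2u^{t}_{x,y} + u^{t}_{x,y-1}}{\Delta y^2} \right),$$
which can be solved for $u^{t+1}_{x,y}$ to obtain an equation for computing the state of the field at a desired position in the subsequent time step:
$$u^{t+1}_{x,y} = 2u^{t}_{x,y} - u^{t-1}_{x,y} + c^2 \Delta t^2 \left( \frac{u^{t}_{x+1,y} - 2u^{t}_{x,y} + u^{t}_{x-1,y}}{\Delta x^2} + \frac{u^{t}_{x,y+1} - 2u^{t}_{x,y} + u^{t}_{x,y-1}}{\Delta y^2} \right).$$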
Boundary conditions in both space (along the borders of the field in x- and y-direction) and time (at time step 0) are treated as zero. The following variable choices were made: , and . The field was initialized using the Gaussian bell curve
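$$u(x, y, 0) = a \exp\left(-\left(\frac{(x - x_0)^2}{2\sigma_x^2} + \frac{(y - y_0)^2}{2\sigma_y^2}\right)\right)$$
with amplitude factor $a$, wave widths $\sigma_x$ and $\sigma_y$ in x- and y-direction, and $(x_0, y_0)$ being the starting point, or center, of the circular wave.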
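The data generation for data set 2 can be sketched as an explicit leapfrog scheme that implements the central-difference update described above. It starts from a Gaussian initial field and holds the spatial borders at zero, so waves reflect there. All defaults are illustrative assumptions rather than the paper's settings: the grid size, c, the step sizes, the amplitude, equal widths in x and y, and the center. With c = 1, dx = dy = 1, and dt = 0.5, the Courant number c * dt * sqrt(1/dx^2 + 1/dy^2) is about 0.71. This stays below 1, as the explicit scheme requires for stability. Setting the field before time step 0 to zero is our reading of the zero time-boundary condition.

```python
import numpy as np


def simulate_wave(size=16, steps=80, c=1.0, dx=1.0, dy=1.0, dt=0.5,
                  amplitude=1.0, sigma=1.0, center=(8, 8)):
    """Sketch of an explicit central-difference solver for u_tt = c^2 (u_xx + u_yy).

    Returns an array of shape (steps, size, size). All defaults are illustrative.
    """
    ys, xs = np.mgrid[0:size, 0:size].astype(float)
    # Gaussian bell initialization centered at (x0, y0), same width in x and y (assumption)
    u = amplitude * np.exp(-((xs - center[0]) ** 2 / (2 * sigma ** 2)
                             + (ys - center[1]) ** 2 / (2 * sigma ** 2)))
    u_prev = np.zeros_like(u)  # field before time step 0 treated as zero (our reading)
    frames = [u.copy()]
    for _ in range(steps - 1):
        lap_x = np.zeros_like(u)
        lap_y = np.zeros_like(u)
        # second-order central differences in x (columns) and y (rows)
        lap_x[:, 1:-1] = (u[:, 2:] - 2 * u[:, 1:-1] + u[:, :-2]) / dx ** 2
        lap_y[1:-1, :] = (u[2:, :] - 2 * u[1:-1, :] + u[:-2, :]) / dy ** 2
        u_next = 2 * u - u_prev + (c * dt) ** 2 * (lap_x + lap_y)
        # zero boundary conditions in space; waves are reflected at the borders
        u_next[0, :] = 0.0
        u_next[-1, :] = 0.0
        u_next[:, 0] = 0.0
        u_next[:, -1] = 0.0
        u_prev, u = u, u_next
        frames.append(u.copy())
    return np.stack(frames)
```

Starting the field at rest instead (setting `u_prev = u.copy()`) is an equally simple variant. It would change the initial amplitude behavior but not the reflecting dynamics.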
The unfolding dynamics of higher complexity are much harder to predict, as can be seen in Figure 4. None of the previously tested architectures was able to approximate the dynamics satisfactorily: their test errors in Table 2 remain larger than the baseline test errors. Accordingly, DISTANA was adapted slightly in three ways, described in the following.
Table 2: Model (number of parameters), training error, test error, and inference time.
In DISTANA v1, the size of the preprocessing feed-forward layer in the PK was increased from one to four neurons.
In DISTANA v2, in addition to the enlarged preprocessing layer of DISTANA v1, the number of lateral input neurons was increased from one to eight. Furthermore, instead of using TKs to transfer information, the PKs were designed to communicate directly with all eight grid neighbors. Each lateral input neuron consistently receives input from one particular neighbor, depending on its direction.
DISTANA v3 applies the same changes as DISTANA v2 and additionally increases the number of lateral output neurons from one to eight, analogously to the lateral input neurons.
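The direction-specific lateral routing of DISTANA v2 and v3 can be sketched as a gather over the eight grid offsets. In this NumPy sketch, lateral input neuron k of every PK reads the neighbor at offset k. A v2-style PK exposes one lateral output neuron that all neighbors read, whereas a v3-style PK exposes eight. For v3, we assume each neighbor sends the output neuron that points back toward the receiver. The name `gather_directional`, the direction ordering, the zero padding at the borders, and this pairing of output and input neurons are assumptions for illustration.

```python
import numpy as np

# The eight grid directions as (row offset, column offset); the ordering is an assumption.
DIRECTIONS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def gather_directional(lateral_out):
    """lateral_out: (H, W, 1) for a v2-style PK or (H, W, 8) for a v3-style PK.

    Returns lateral_in of shape (H, W, 8); entry k holds what the neighbor at offset
    DIRECTIONS[k] sends (zero outside the grid).
    """
    height, width, n_out = lateral_out.shape
    if n_out not in (1, len(DIRECTIONS)):
        raise ValueError("expected 1 (v2-style) or 8 (v3-style) lateral output neurons")
    padded = np.pad(lateral_out, ((1, 1), (1, 1), (0, 0)))
    lateral_in = np.zeros((height, width, len(DIRECTIONS)))
    for k, (dr, dc) in enumerate(DIRECTIONS):
        neighbor = padded[1 + dr:1 + dr + height, 1 + dc:1 + dc + width]
        if n_out == 1:
            lateral_in[..., k] = neighbor[..., 0]  # v2: one shared output neuron
        else:
            # v3: the neighbor uses its output neuron that points back toward the receiver
            back = DIRECTIONS.index((-dr, -dc))
            lateral_in[..., k] = neighbor[..., back]
    return lateral_in
```

Compared with the summed TK message of the basic model, this layout keeps the information from each direction in a separate input neuron. The PK can therefore learn direction-dependent propagation.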
As a result, DISTANA v2 and DISTANA v3 strongly outperform the simpler DISTANA version as well as TCN and ConvLSTM. Table 2 shows that DISTANA v2 not only reaches the lowest training error but also yields the best generalization performance.
Figure 4 shows that when closed-loop predictions unfold after 15 steps of teacher forcing, DISTANA v2 and DISTANA v3 still approximate the target values rather well. The other ANN architectures, in contrast, start to deviate strongly from the target values after only 10 to 15 closed-loop prediction steps.
Several ANN architectures were compared at approximating a spatio-temporal process, the circular wave, in two scenarios of different complexity. The performance comparisons on data set 1 show that only ConvLSTM and our model, DISTANA, yield smaller test errors than the two baselines. Recall that here the closed-loop performance over 65 prediction time steps was measured. This is much harder than mere next-time-step prediction, because it requires both intrinsic model stability and the maintenance of plausible ongoing dynamics. The reported accuracy for standard training and simple dynamics favors ConvLSTM. DISTANA, however, proved robust to scarce and variable training data (see Figure 3 and the last two columns of Table 1), even with a network that contains only 106 parameters. In these latter cases, ConvLSTM1 tended to approach the zero baseline, while DISTANA still predicted the actual wave amplitude. DISTANA did so with a slightly different wave frequency, which prevented an even smaller error. ConvLSTM8 reached outstanding test errors but tended to overfit heavily, as indicated by its comparatively poor test error when trained on only one training example (Figure 3).
These findings were corroborated by the evaluations on a second, more complex data set, in which waves were reflected at the borders and thus interacted strongly with each other. All considered architectures failed to generate lasting closed-loop predictions, except for the two DISTANA variants that explicitly consider lateral information propagation (Figure 4). DISTANA did not tend to overfit and generalized very well, because by design it assumes equal dynamic propagation principles throughout space and time.
So far, we have only considered regularly distributed grid data, where the distances between single measurement points are identical. However, this is not the case in many applications, such as neural information processing in the brain, traffic prediction on roads, data propagation on graphs, or weather prediction from weather station networks, radar, and satellite data. In all these cases, distances and the number of neighboring vertices vary, and we are currently enhancing the TKs in DISTANA accordingly. Ongoing work shows that DISTANA can indeed handle irregularly distributed sensor meshes. We are particularly interested in whether DISTANA will detect coherent and predictable structures in weather data, which would indicate its scalable applicability to short-range weather forecasting.
We thank Georg Martius for inspiring ideas and Nicholas Krämer for sharing expertise in complex data generation; we thank the International Max Planck Research School for Intelligent Systems (IMPRS-IS) for supporting Matthias Karlbauer, and gratefully mention that this work is funded by the German Research Foundation (DFG) under Germany’s Excellence Strategy – EXC-Number 2064/1 – Project Number 390727645.
[1] (2017) Geometric deep learning: going beyond Euclidean data. IEEE Signal Processing Magazine 34 (4), pp. 18–42.
[2] (2018) Exploratory social network analysis with Pajek: revised and expanded edition for updated software. Vol. 46, Cambridge University Press.
[3] (1997) Long short-term memory. Neural Computation 9 (8), pp. 1735–1780.
[4] (2015) Grid long short-term memory. arXiv preprint arXiv:1507.01526.
[5] (2014) Large-scale video classification with convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1725–1732.
[6] (2014) NeuCube: a spiking neural network architecture for mapping, learning and understanding of spatio-temporal brain data. Neural Networks 52, pp. 62–76.
[7] (1989) Backpropagation applied to handwritten zip code recognition. Neural Computation 1 (4), pp. 541–551.
[8] (2015) Information granularity, big data, and computational intelligence. W. Pedrycz and S. M. Chen (Eds.), Studies in Big Data, Vol. 8, pp. 389–408.
[9] (2008) The graph neural network model. IEEE Transactions on Neural Networks 20 (1), pp. 61–80.
[10] (2015) Convolutional LSTM network: a machine learning approach for precipitation nowcasting. In Advances in Neural Information Processing Systems 28, C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama and R. Garnett (Eds.), pp. 802–810.
[11] (2017) LSTM network: a deep learning approach for short-term traffic forecast. IET Intelligent Transport Systems 11 (2), pp. 68–75.