Robust learning with implicit residual networks
In this effort, we propose a new deep architecture utilizing residual blocks inspired by implicit discretization schemes. As opposed to the standard feed-forward networks, the outputs of the proposed implicit residual blocks are defined as the fixed points of appropriately chosen nonlinear transformations. We show that this choice improves the stability of both forward and backward propagation, has a favorable impact on the generalization power and allows one to control the robustness of the network with only a few hyperparameters. In addition, the proposed reformulation of ResNet does not introduce new parameters and can potentially lead to a reduction in the number of required layers due to improved forward stability. Finally, we derive a memory-efficient training algorithm, propose a stochastic regularization technique and provide numerical results in support of our findings.
1 Introduction and related works
A large volume of empirical results has been collected in recent years illustrating the striking success of deep neural networks (DNNs) in approximating complicated maps by a mere composition of relatively simple functions LeCun et al. (2015). The universal approximation property of DNNs with a relatively small number of parameters has also been shown for a large class of functions Hanin (2017); Lu et al. (2017). The training of deep networks nevertheless remains a notoriously difficult task due to the issues of exploding and vanishing gradients, which become more apparent with increasing depth Bengio et al. (1994). These issues have accelerated the efforts of the research community to explain this behavior and to gain new insights into the design of better architectures and faster algorithms. A promising approach in this direction was obtained by casting the evolution of the hidden states of a DNN as a dynamical system E (2017), i.e.,
where for each layer , is a nonlinear transformation parameterized by the weights and , are the appropriately chosen spaces. In the case of a very deep network, when , it is convenient to consider the continuous time limit of the above expression such that
where the parametric evolution function defines a continuous flow through the input data . Parameter estimation for such continuous evolution can be viewed as an optimal control problem E et al. (2018), given by
where is a terminal loss function, is a regularizer, and is a probability distribution of the input-target data pairs . More general models additionally consider continuity in the “spatial” dimension as well by using differential Ruthotto and Haber (2018) or integral formulations Sonoda and Murata (2017). A continuous time formulation based on ordinary differential equations (ODEs) was proposed in Chen et al. (2018) with the state equation (2) of the form
In the work Chen et al. (2018), the authors relied on the black-box ODE solvers and used adjoint sensitivity analysis to derive equations for the backpropagation of errors through the continuous system.
The authors of Haber and Ruthotto (2017) concentrated on the well-posedness of the learning problem for ODE-constrained control and emphasized the importance of stability in the design of deep architectures. For instance, the solution of a homogeneous linear ODE with constant coefficients
is given by
where is the eigen-decomposition of a matrix , and is the diagonal matrix with the corresponding eigenvalues. Similar equation holds for the backpropagation of gradients. To guarantee the efficient propagation of information through the network, one must ensure that the elements of have magnitudes close to one. This condition, of course, is satisfied when all eigenvalues of the matrix are imaginary with real parts close to zero. In order to preserve this property, the authors of Haber and Ruthotto (2017) proposed several time continuous architectures of the form
When , , the equations above provide an example of a conservative Hamiltonian system with the total energy .
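Returning to the linear ODE example, the argument can be written out as a short worked derivation. The notation is assumed for illustration only (state y, matrix A with eigenvector matrix Q and eigenvalues λ_j) and is not necessarily the authors' own. The last line shows why eigenvalues on the imaginary axis neither amplify nor damp the propagated signal.

```latex
\begin{aligned}
\dot{y}(t) &= A\, y(t), \qquad A = Q \Lambda Q^{-1}, \quad \Lambda = \mathrm{diag}(\lambda_1, \dots, \lambda_d), \\
y(t) &= Q\, e^{\Lambda t}\, Q^{-1} y(0), \\
\big| e^{\lambda_j t} \big| &= e^{\mathrm{Re}(\lambda_j)\, t} = 1 \quad \text{whenever } \mathrm{Re}(\lambda_j) = 0 .
\end{aligned}
```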
In the discrete setting of the ordinary feed-forward networks, the necessary conditions for the optimal solution of (1)-(2) recover the well-known equations for the forward propagation (state equation (2)), the backward gradient propagation (co-state equation), and the optimality condition used to compute the weights (gradient descent algorithm); see, e.g., LeCun (1988). The continuous setting offers additional flexibility in the construction of discrete networks with the desired properties and efficient learning algorithms. The classical feed-forward network (Figure 1, left) is just one particular, and the simplest, example of such a discretization, and it is prone to all the issues of deep learning. In order to facilitate the training process, a skip-connection is often added to the network (Figure 1, middle) yielding
where is a positive hyperparameter. Equation (5) can be viewed as a forward Euler scheme to solve the ODE in (3) numerically on the time grid with step size . While it was shown that such residual layers help to mitigate the problem of vanishing gradients and speed up the training process He et al. (2016), the scheme has very restrictive stability properties Hairer et al. (1993). This can result in the uncontrolled accumulation of errors at the inference stage, reducing the generalization ability of the trained network. Moreover, the Euler scheme is not capable of preserving the geometric structure of conservative flows and is thus a bad choice for the long-time integration of such ODEs Hairer et al. (2006).
Memory-efficient explicit reversible architectures can be obtained by considering a time discretization of the partitioned system of ODEs in (4). The reversibility property allows one to recover the internal states of the system by propagating through the network in both directions, so these values need not be cached for the evaluation of the gradients. The first such architecture (RevNet) was proposed in Gomez et al. (2017) without using a connection to discrete solutions of ODEs; it has the form
It was later recognized as the Verlet method applied to the particular form of the system in (4), see Haber and Ruthotto (2017); Chang et al. (2018). The leapfrog and midpoint networks are two other examples of reversible architectures proposed in Chang et al. (2018).
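The reversibility property can be illustrated with a minimal scalar sketch of a RevNet-style block in the spirit of Gomez et al. (2017). The functions `f` and `g` below are hypothetical stand-ins for the learned residual functions, and the scalar states are a simplification; the point is only that the inputs can be recovered from the outputs without caching them.

```python
import math

def f(x):
    # Hypothetical residual function applied to the second partition.
    return math.tanh(x)

def g(x):
    # Hypothetical residual function applied to the first partition.
    return 0.5 * math.sin(x)

def forward(x1, x2):
    """Reversible block: each half is updated using only the other half."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def inverse(y1, y2):
    """Recover the inputs by undoing the two updates in reverse order."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2
```

Calling `inverse(*forward(x1, x2))` returns the original pair up to floating-point round-off, which is what allows the backward pass to recompute activations instead of storing them.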
Other residual architectures can also be found in the literature, including Resnet in Resnet (RiR) Targ et al. (2016), Dense Convolutional Network (DenseNet) Huang et al. (2017) and the linearly implicit network (IMEXNet) Haber et al. (2019). For some problems, all of these networks show a substantial improvement over the classical ResNet but still have an explicit structure, which has limited robustness to perturbations of the input data and of the parameters of the network. Instead, in this effort we propose a new fully implicit residual architecture which, unlike the above-mentioned examples, is unconditionally stable and robust. As opposed to the standard feed-forward networks, the outputs of the proposed implicit residual blocks are defined as the fixed points of appropriately chosen nonlinear transformations as follows:
The right part of Figure 1 provides a graphical illustration of the proposed layer. The choice of the nonlinear transformation and the design of the learning algorithm are discussed in the next section.
2 Description of the method
We first motivate the necessity for our new method by letting the continuous model of a network be given by the ordinary differential equations in (4), that is:
An s-stage Runge-Kutta method for the approximate solution of the above equations is given by
The order conditions for the coefficients , , , , , and , which guarantee convergence of the numerical solution are well known and can be found in any topical text, see, e.g., Hairer et al. (1993). Note that when or for at least some , the scheme is implicit and a system of nonlinear equations has to be solved at each iteration which obviously increases the complexity of the solver. Nevertheless, the following example illustrates the benefits of using implicit approximations.
Linear stability analysis.
Consider the following linear differential system
and four simple discretization schemes:
and Verlet scheme
Due to linearity of the system in (6), we can write the generated numerical solutions as
The long-time behavior of the discrete dynamics is hence determined by the spectral radius of the matrix which needs to be less than or equal to one for stability. For example, we have for the forward Euler scheme and the method is unconditionally unstable. The backward Euler scheme gives and the method is unconditionally stable. The corresponding eigenvalues of the trapezoidal scheme have magnitude equal to one for all and . Finally, the characteristic polynomial for the matrix of the Verlet scheme is given by , i.e., the method is only conditionally stable when .
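The stability statements above can be checked numerically with a short sketch. It assumes that the test problem is the harmonic oscillator y' = ωz, z' = -ωy; this concrete system, the Verlet variant with half steps in z, and all function names are illustrative assumptions rather than the authors' setup. The code builds the one-step matrix of each scheme and computes its spectral radius.

```python
import cmath

def spectral_radius(m):
    """Largest eigenvalue magnitude of a 2x2 matrix, from its trace and determinant."""
    (a, b), (c, d) = m
    trace, det = a + d, a * d - b * c
    root = cmath.sqrt(trace * trace - 4.0 * det)
    return max(abs((trace + root) / 2.0), abs((trace - root) / 2.0))

def inverse(m):
    """Inverse of a 2x2 matrix."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def matmul(p, q):
    """Product of two 2x2 matrices."""
    return [[sum(p[i][k] * q[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def amplification_radii(h, omega):
    """Spectral radii of the one-step matrices for y' = omega*z, z' = -omega*y."""
    a = h * omega
    forward_euler = [[1.0, a], [-a, 1.0]]                  # I + h*A
    backward_euler = inverse([[1.0, -a], [a, 1.0]])        # (I - h*A)^{-1}
    trapezoidal = matmul(inverse([[1.0, -a / 2], [a / 2, 1.0]]),
                         [[1.0, a / 2], [-a / 2, 1.0]])    # (I - h*A/2)^{-1} (I + h*A/2)
    verlet = [[1.0 - a * a / 2, a],
              [-a * (1.0 - a * a / 4), 1.0 - a * a / 2]]   # Verlet with half steps in z
    return {
        "forward_euler": spectral_radius(forward_euler),
        "backward_euler": spectral_radius(backward_euler),
        "trapezoidal": spectral_radius(trapezoidal),
        "verlet": spectral_radius(verlet),
    }
```

For this assumed oscillator, the forward Euler radius exceeds one and the backward Euler radius is below one for every positive step, the trapezoidal radius equals one, and the Verlet radius equals one only while the product of step size and frequency stays below two.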
Figure 2 illustrates this behavior for the particular case of . Notice that the flows of the forward and backward Euler schemes are strictly expanding and contracting, respectively, which makes the training process inherently ill-posed as the dynamics are not easily invertible. In contrast, the implicit trapezoidal and explicit Verlet schemes seem to reproduce the original flow very well, but the latter is conditional on the size of the step . Another nice property of the trapezoidal and Verlet schemes is their symmetry with respect to exchanging and . Such methods play a central role in the geometric integration of reversible differential flows and are handy in the construction of memory-efficient reversible network architectures. Conditions for the reversibility of general Runge-Kutta schemes can be found in Hairer et al. (2006).
2.1 Implicit ResNet.
Motivated by the discussion above, we propose an implicit variant of the residual layer given by
where , , are the input, output and parameters of the layer and is a nonlinear function.
To solve the nonlinear equation in (7), consider the equivalent minimization problem
One way to construct the required solution is by applying the simple gradient descent algorithm
or the more advanced nonlinear conjugate gradient method
with the conjugate direction and the corresponding optimal step size of the form
The choice of the parameter determines the particular variant of the conjugate gradient method. For example, the Fletcher-Reeves and Polak-Ribière update rules are often used in practice Nocedal and Wright (2006). Note that the required gradients of the residual in (8) can be efficiently computed with the automatic differentiation capabilities of any standard deep learning framework, making it possible to interface efficiently with existing solvers.
Finally, simple fixed-point iterations may suffice if one can ensure the convergence of such an iterative procedure
In this case, there is no need to interface external solvers which simplifies implementation of the algorithm.
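A minimal scalar sketch of such a fixed-point solve is given below. It assumes a simplified implicit layer of the form y = x + h·F(y); the residual function `residual_fn`, its weight and bias, and the step `h` are hypothetical choices made so that the iteration is a contraction, and they are not taken from the paper.

```python
import math

def residual_fn(y, weight=0.8, bias=0.1):
    # Hypothetical one-dimensional residual function. Since tanh is 1-Lipschitz,
    # the map y -> x + h * residual_fn(y) is a contraction when h * |weight| < 1.
    return math.tanh(weight * y + bias)

def implicit_layer(x, h=0.5, tol=1e-10, max_iter=100):
    """Solve y = x + h * residual_fn(y) by fixed-point iteration, starting from y = x."""
    y = x
    for iteration in range(1, max_iter + 1):
        y_next = x + h * residual_fn(y)
        if abs(y_next - y) < tol:
            # Successive iterates agree to within the tolerance.
            return y_next, iteration
        y = y_next
    # Iteration budget exhausted: return the last iterate.
    return y, max_iter
```

Because the initial guess is the layer input, the first iterate `x + h * residual_fn(x)` coincides with the update of an explicit residual layer, and further iterations only refine it.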
It is worth noting that, even though the nonlinearity in (7) adds to the complexity of the forward propagation, the direct backpropagation through the nonlinear solver is not required as is shown below.
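The reason can be sketched with the implicit function theorem. The block below assumes a simplified layer of the form y = x + F(γ, y) with a scalar loss L and gradients treated as column vectors; the symbols and the exact form of the layer are assumptions for illustration and may differ from the authors' formulation.

```latex
\begin{aligned}
y &= x + F(\gamma, y), \\
\frac{\partial y}{\partial x} &= \Big(I - \frac{\partial F}{\partial y}\Big)^{-1}, \qquad
\frac{\partial y}{\partial \gamma} = \Big(I - \frac{\partial F}{\partial y}\Big)^{-1} \frac{\partial F}{\partial \gamma}, \\
\frac{\partial L}{\partial x} &= z, \qquad
\frac{\partial L}{\partial \gamma} = \Big(\frac{\partial F}{\partial \gamma}\Big)^{T} z, \qquad
\text{where } \Big(I - \frac{\partial F}{\partial y}\Big)^{T} z = \frac{\partial L}{\partial y}.
\end{aligned}
```

Only the converged output y enters these expressions, so the intermediate iterates of the nonlinear solver never need to be differentiated.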
Using the chain rule we can easily find the Jacobian matrices of the implicit residual layer as follows
The backpropagation formulas then follow immediately
where is a solution to the linear system
The ability of automatic differentiation software to efficiently compute vector-Jacobian products can be combined with iterative linear solvers to find the solution of this linear system. Moreover, only one such linear solve is required at each layer. Additionally, Neumann series expansion of the matrix inverse
can be used when , i.e., for small values of or when is a contraction. By analogy with fixed-point iterations, this would remove the dependence on external linear solvers and simplify the implementation.
To provide an efficient implementation of the proposed algorithm, we need to satisfy two conditions. First, the existence of the stable fixed point in (7) must be ensured to guarantee the overall convergence of the algorithm. Second, the dependence on external libraries should be minimized or avoided completely when possible.
The first condition can be satisfied by imposing a hard constraint on the parameters of the network to make it globally diffusive or conservative. In combination with explicit solvers, this approach has been utilized, for instance, in Haber and Ruthotto (2017); Ruthotto and Haber (2018). Instead, we propose a more flexible regularization technique which both enforces an appropriate structure on the parameters of the network and promotes computational efficiency.
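The Neumann-series idea can be sketched in a few lines. The example solves (I - M) z = b using only products with M, which in a deep learning framework would be vector-Jacobian products. The small matrix `M`, the helper names and the stopping rule are hypothetical; here M merely stands in for the transposed Jacobian of the residual function, and the series converges because its spectral radius is below one.

```python
def neumann_solve(matvec, b, tol=1e-12, max_terms=200):
    """Approximate z = (I - M)^{-1} b by the truncated series b + M b + M^2 b + ...

    `matvec(v)` must return M v. The series converges when the spectral
    radius of M is below one.
    """
    z = list(b)
    term = list(b)
    for _ in range(max_terms):
        term = matvec(term)                       # next power of M applied to b
        z = [zi + ti for zi, ti in zip(z, term)]  # accumulate the partial sum
        if max(abs(ti) for ti in term) < tol:
            break
    return z

# Hypothetical 2x2 stand-in for the transposed Jacobian; its eigenvalues are 0.2 and 0.3.
M = [[0.2, 0.1], [0.0, 0.3]]

def matvec(v):
    """Matrix-vector product M v without forming any inverse."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in M]
```

Each term costs one product with M, so the number of terms plays the same role for the backward pass as the number of fixed-point iterations does for the forward pass.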
The idea of the approach is to rely exclusively on the fixed-point iterative procedure. Assume that successful fixed-point iterations have been performed such that the residuals of the successive approximations are monotonically decaying for . By taking , the first iteration has the form
which is simply the output of the standard residual layer. Hence, even if it has not converged, the iterative process will perform no worse than an explicit layer.
The proposed regularizer has the form
where each is a function of . Minimization of leads to the reduction of the number of iterations required for the convergence to the specified tolerance. Coefficients control the convergence rate and smaller values of correspond to faster rates.
A successfully trained network equipped with such structure is expected to have enhanced stability properties on the training dataset. To improve the robustness of the network to perturbations of the input data, we introduce an additional regularizer
where are the states of the same network corresponding to random perturbations of the input data. These perturbations are not fixed prior to the training and are generated independently for each training epoch. The aim of this regularizer is to avoid the accumulation of perturbations across the network. The value of controls the damping properties of the trained network, with smaller values of corresponding to faster damping.
Figure 3 provides an example of the proposed architecture with layers and concurrent perturbation streams. The type and magnitude of perturbations can be varied and depend on the particular application.
The proposed implicit layer in Figure 1 has a self feedback connection which is also present in some recurrent networks. The main difference, however, is that in our case, this feedback loop is designed with a specific goal of driving the state of the network to the stable fixed point.
The authors of Chen et al. (2018) proposed a neural network architecture and learning algorithm to estimate continuous time differential equations. This approach is beneficial in a number of applications, allows for adaptive error control and saves memory by using adjoint equations for the backward propagation. The authors also applied implicit methods with the adjoint method in Chen and Duvenaud (2019) but did not consider backpropagating through implicit solvers at the discrete level. In comparison, our approach has the predictable computational cost of a fixed-grid solver and the same parameter complexity, and it provides the backpropagation and regularization algorithms to train robust and stable implicit discrete time networks.
Lipschitz continuous architectures.
Enforcing Lipschitz continuity is a popular approach to stabilize neural networks. Weight normalization Miyato et al. (2018); Gouk et al. (2018) and Lipschitz based regularization Qi (2019); Tsuzuku et al. (2018) are among possible options. Implicit layers can be used as another way to control the Lipschitz constant of the whole network.
The proposed residual architecture can be easily implemented using any existing deep learning framework such as PyTorch or TensorFlow. The code snippet in Listing 1 gives an example of such an implementation in PyTorch. First, we disable gradient tracking to avoid backpropagating through the nonlinear solver, and then we compose the output of the layer with the custom_backprop function, which represents the identity map during forward propagation and is responsible for the linear solve in the backpropagation.
The computational cost of the implicit network is necessarily larger than that of a ResNet of the same depth since additional nonlinear and linear solves are required at each layer. The cost of the linear solver strongly depends on the structure of the linear operator. For general dense matrices it is on the order of . In practice, the dimension of hidden states is often not very large or the linear operator is of special structure. For instance, sparse convolutional operators should not be cast into matrix form, and the corresponding linear systems can be solved by iterative methods which only require one to know how to apply a particular operator to the given tensor. The cost of a general nonlinear solver is more difficult to estimate since the convergence is highly dependent on the initialization. In practice, however, we aim to avoid general nonlinear solvers in favor of fixed-point iterations. The cost of the implicit network is then at most times larger than that of ResNet, where is the maximum allowed number of such iterations.
The memory complexity of the raw implicit network is the same as that of ResNet since only input and output to each layer need to be cached to perform forward and backward propagations. The regularizer in (9), however, additionally requires caching all intermediate iterations of the fixed point iterative procedure. This may result in some memory overhead at the start of the training procedure which nevertheless will relax once the number of required iterations reduces.
In all examples below, we use residual networks derived from the ODE model with hidden states
where is the input to the network.
We choose ReLU activation, set the tolerance of the nonlinear and linear solvers to and limit the allowed number of fixed point iterations to per layer.
In our first example, we aim to approximate the following one-dimensional function
The initial stationary weights , are drawn from the standard normal distribution and then normalized as
where is the spectral radius of the matrix . The bias is initially set to zero.
We used the training dataset of equally distributed points, and trained the networks using Adam optimizer and loss function
The initial learning rate was set to and reduced by the factor every epochs.
Figure 4 shows the convergence of the loss on the training and validation datasets for several residual networks with and layers. It is seen that the explicit ResNet has the best training accuracy among networks with layers but is second worst on the validation dataset. Figure 7 explains this behavior by showing that the explicit ResNet is unstable; the same figure shows that all implicit networks are stable. Increased stability on the coarse time grids is the reason that implicit networks perform a bit worse in terms of training loss. At the same time, all the networks with layers perform equally well, as expected, except the explicit ResNet which, according to Figure 6, is stuck at a local minimum and struggles to propagate its gradients. Figure 8 illustrates the behavior of networks with layers in more detail. Additionally, Figure 5 shows the convergence of fixed-point and Neumann iterations.
For the second example, we consider a classification problem from Haber and Ruthotto (2017). The dataset consists of points organized in two differently labeled spirals. Every other point was removed to be used as the validation dataset.
We use ResNet with 100 layers to guarantee stability and apply the stochastic regularization technique in (10). Figure 9 illustrates classification results of the original network and three networks trained with Gaussian noise perturbations with standard deviations , and . One can see that the proposed regularization technique results in more accurate and robust networks than the original ResNet.
This material is based upon work supported in part by: the U.S. Department of Energy, Office of Science, Early Career Research Program under award number ERKJ314; U.S. Department of Energy, Office of Advanced Scientific Computing Research under award numbers ERKJ331 and ERKJ345; the National Science Foundation, Division of Mathematical Sciences, Computational Mathematics program under contract number DMS1620280; and the Behavioral Reinforcement Learning Lab at Lirio LLC.
- Bengio et al. (1994) Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks 5 (2), pp. 157–166.
- Chang et al. (2018) Reversible architectures for arbitrarily deep residual neural networks. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Chen and Duvenaud (2019) Neural networks with cheap differential operators. In Advances in Neural Information Processing Systems 32, H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox and R. Garnett (Eds.), pp. 9961–9971.
- Chen et al. (2018) Neural ordinary differential equations. In Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi and R. Garnett (Eds.), pp. 6571–6583.
- E et al. (2018) A mean-field optimal control formulation of deep learning. Research in the Mathematical Sciences 6 (1), pp. 10.
- E (2017) A proposal on machine learning via dynamical systems. Communications in Mathematics and Statistics 5 (1), pp. 1–11.
- Gomez et al. (2017) The reversible residual network: backpropagation without storing activations. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan and R. Garnett (Eds.), pp. 2214–2224.
- Gouk et al. (2018) Regularisation of neural networks by enforcing Lipschitz continuity. arXiv preprint arXiv:1804.04368.
- Haber et al. (2019) IMEXnet: a forward stable deep neural network. arXiv e-prints, pp. arXiv:1903.02639.
- Haber and Ruthotto (2017) Stable architectures for deep neural networks. Inverse Problems 34 (1), pp. 014004.
- Hairer et al. (2006) Geometric numerical integration: structure-preserving algorithms for ordinary differential equations. Springer Series in Computational Mathematics, Vol. 31, Springer-Verlag Berlin Heidelberg.
- Hairer et al. (1993) Solving ordinary differential equations I, nonstiff problems. Springer Series in Computational Mathematics, Vol. 8, Springer-Verlag Berlin Heidelberg.
- Hanin (2017) Universal function approximation by deep neural nets with bounded width and ReLU activations. arXiv preprint arXiv:1708.02691.
- He et al. (2016) Identity mappings in deep residual networks. In Computer Vision – ECCV 2016, B. Leibe, J. Matas, N. Sebe and M. Welling (Eds.), Cham, pp. 630–645.
- Huang et al. (2017) Densely connected convolutional networks. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
- LeCun et al. (2015) Deep learning. Nature 521, pp. 436–444.
- LeCun (1988) A theoretical framework for back-propagation. In Proceedings of the 1988 Connectionist Models Summer School, D. Touretzky, G. Hinton and T. Sejnowsky (Eds.), CMU, Pittsburgh, Pa, pp. 21–28.
- Lu et al. (2017) The expressive power of neural networks: a view from the width. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan and R. Garnett (Eds.), pp. 6231–6239.
- Miyato et al. (2018) Spectral normalization for generative adversarial networks. arXiv preprint arXiv:1802.05957.
- Nocedal and Wright (2006) Numerical optimization. Springer Science & Business Media.
- Qi (2019) Loss-sensitive generative adversarial networks on Lipschitz densities. International Journal of Computer Vision.
- Ruthotto and Haber (2018) Deep Neural Networks Motivated by Partial Differential Equations. arXiv e-prints, pp. arXiv:1804.04272.
- Sonoda and Murata (2017) Double continuum limit of deep neural networks. In ICML Workshop on Principled Approaches to Deep Learning.
- Targ et al. (2016) Resnet in resnet: generalizing residual architectures. arXiv preprint arXiv:1603.08029.
- Tsuzuku et al. (2018) Lipschitz-margin training: scalable certification of perturbation invariance for deep neural networks. In Advances in Neural Information Processing Systems, pp. 6541–6550.