Ada-LISTA: Learned Solvers Adaptive to Varying Models
Neural networks that are based on unfolding of an iterative solver, such as LISTA (learned iterative soft threshold algorithm), are widely used due to their accelerated performance. Nevertheless, as opposed to non-learned solvers, these networks are trained on a certain dictionary, and therefore they are inapplicable for varying model scenarios. This work introduces an adaptive learned solver, termed Ada-LISTA, which receives pairs of signals and their corresponding dictionaries as inputs, and learns a universal architecture to serve them all. We prove that this scheme is guaranteed to solve sparse coding in linear rate for varying models, including dictionary perturbations and permutations. We also provide an extensive numerical study demonstrating its practical adaptation capabilities. Finally, we deploy Ada-LISTA to natural image inpainting, where the patch-masks vary spatially, thus requiring such an adaptation.
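Sparse coding is the task of representing a noisy signal $y \in \mathbb{R}^{n}$ as a combination of few base signals (called “atoms”), taken from a matrix $D \in \mathbb{R}^{n \times m}$, the “dictionary”. This is represented as the need to compute $x \in \mathbb{R}^{m}$ such that
$$y = Dx + e, \quad \text{s.t.} \quad \|x\|_0 \le s, \tag{1}$$
where $e$ is the additive noise, the $\ell_0$-norm counts the non-zero elements, $s$ is the cardinality of the representation, and $D$ is often redundant ($m \ge n$). Among the various approximation methods for handling this NP-hard task, an appealing approach is a relaxation of the $\ell_0$ to an $\ell_1$-norm using Lasso or Basis-Pursuit Tibshirani (1996); Chen et al. (2001),
$$\min_{x} \; \frac{1}{2}\|y - Dx\|_2^2 + \lambda \|x\|_1. \tag{2}$$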
An effective way to address this optimization problem uses an iterative algorithm such as ISTA (Iterative Soft Thresholding Algorithm) Daubechies et al. (2004), where the solution is obtained by iterations of the form
$$x_{k+1} = \mathcal{S}_{\lambda/L}\left(x_k + \frac{1}{L} D^T (y - D x_k)\right), \tag{3}$$
where $\frac{1}{L}$ is the step size, determined by the largest eigenvalue $L$ of the Gram matrix $D^T D$, and $\mathcal{S}_{\theta}(x_i) = \mathrm{sign}(x_i)\,(|x_i| - \theta)_{+}$ is the soft shrinkage function. Fast-ISTA (FISTA) Beck and Teboulle (2009) is a speed-up of the above iterative algorithm, reminiscent of the momentum method in optimization.
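The two non-learned solvers can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' code: the function names (`soft_threshold`, `lasso_objective`, `ista`, `fista`) are hypothetical, the initial point is assumed to be zero, and the FISTA momentum schedule is the standard one from Beck and Teboulle (2009).

```python
import numpy as np


def soft_threshold(v, theta):
    """Soft shrinkage: sign(v) * max(|v| - theta, 0), applied elementwise."""
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)


def lasso_objective(y, D, x, lam):
    """Value of 0.5 * ||y - D x||_2^2 + lam * ||x||_1."""
    return 0.5 * np.sum((y - D @ x) ** 2) + lam * np.sum(np.abs(x))


def ista(y, D, lam, num_iters):
    """ISTA with step size 1/L, where L is the largest eigenvalue of D^T D."""
    L = np.linalg.eigvalsh(D.T @ D).max()
    x = np.zeros(D.shape[1])  # assumed initialization x_0 = 0
    for _ in range(num_iters):
        x = soft_threshold(x + D.T @ (y - D @ x) / L, lam / L)
    return x


def fista(y, D, lam, num_iters):
    """FISTA: the same shrinkage step, taken from an extrapolated point z."""
    L = np.linalg.eigvalsh(D.T @ D).max()
    x = np.zeros(D.shape[1])
    z, t = x.copy(), 1.0
    for _ in range(num_iters):
        x_next = soft_threshold(z + D.T @ (y - D @ z) / L, lam / L)
        t_next = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        z = x_next + ((t - 1.0) / t_next) * (x_next - x)  # momentum term
        x, t = x_next, t_next
    return x
```

Both functions take the dictionary as an explicit argument, which is why they remain usable when the model changes from one signal to the next.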
As a side note, we mention that ISTA has a much wider perspective when aiming to minimize a function of the form
The above fits various optimization problems such as a projected gradient descent over an indicator function , the matrix completion problem Mazumder et al. (2010), portfolio optimization Boyd and Vandenberghe (2004), non-negative matrix factorization Sprechmann et al. (2015), and more.
Returning to the realm of sparse coding, the seminal work of LISTA (Learned-ISTA) Gregor and LeCun (2010) has shown that by unfolding $K$ iterations of ISTA and freeing its parameters to be learned, one can achieve a substantial speedup over ISTA (and FISTA). Particularly, LISTA uses the following re-parametrization:
$$x_{k+1} = \mathcal{S}_{\theta}\left(W_1 y + W_2 x_k\right), \tag{6}$$
where $W_1$ and $W_2$ re-parametrize the matrices $\frac{1}{L} D^T$ and $I - \frac{1}{L} D^T D$ correspondingly. These two matrices and the scalar thresholding value $\theta$ are collectively referred to as $\Theta = (W_1, W_2, \theta)$, the parameters to be learned. The model, denoted as $\mathcal{F}_K(y; \Theta)$, is trained by minimizing the squared error between the predicted sparse representations at the $K$th unfolding, $x_K = \mathcal{F}_K(y; \Theta)$, and the optimal codes $x$ obtained by running ISTA itself,
$$\min_{\Theta} \; \sum_{i} \left\| \mathcal{F}_K(y_i; \Theta) - x_i \right\|_2^2. \tag{7}$$
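To make the re-parametrization concrete, the sketch below builds the two LISTA matrices from a dictionary and checks that, before any training, each unfolding coincides with one ISTA iteration. The helper names (`lista_init_from_ista`, `lista_forward`) are hypothetical, and a single shared threshold with a zero initial point is assumed.

```python
import numpy as np


def soft_threshold(v, theta):
    """Soft shrinkage applied elementwise."""
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)


def lista_init_from_ista(D, lam):
    """Weights that make each unfolding equal to one ISTA iteration."""
    L = np.linalg.eigvalsh(D.T @ D).max()
    W1 = D.T / L                                # plays the role of (1/L) D^T
    W2 = np.eye(D.shape[1]) - (D.T @ D) / L     # plays the role of I - (1/L) D^T D
    return W1, W2, lam / L


def lista_forward(y, W1, W2, theta, num_unfoldings):
    """K unfoldings of x <- S_theta(W1 y + W2 x), starting from x = 0."""
    b = W1 @ y
    x = np.zeros(W2.shape[0])
    for _ in range(num_unfoldings):
        x = soft_threshold(b + W2 @ x, theta)
    return x
```

Note that `lista_forward` never receives the dictionary: once `W1` and `W2` are trained, the model they encode is fixed, which is the limitation discussed in the text.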
Once trained, LISTA requires only the test signals during inference, without their underlying dictionary. It has been shown in Gregor and LeCun (2010) that LISTA generalizes well for signals of the same distribution as in the training set, allowing a significant speedup versus its non-learned counterparts. This may be explained by the fact that while non-learned solvers do not make any assumption on the input signals, LISTA fits itself to the input distribution. More specifically, in sparse coding, the input signals are restricted to a union of low-dimensional Gaussians, as they are generated by a linear combination of few atoms. Focusing solely on such signals is what allows LISTA to achieve its acceleration. Note, however, that the original dictionary is hard-coded into the model weights via the ground truth solutions used during the supervised training. Given a new test sample that emerges from a slightly deviated (yet known) model/dictionary, LISTA will most likely deteriorate in performance, whereas ISTA and FISTA are expected to provide a robust and consistent result, as they are agnostic to the input signals and dictionary.
From a different point of view, a drawback of LISTA is its relevance to a single dictionary, requiring a separate and renewed training if the model evolves over time. Such is the case in video-related applications such as enhancement Protter and Elad (2008) or surveillance Zhao et al. (2011), where the dictionary should vary over time. Similarly, in some image restoration problems, the model encapsulated by the dictionary is often corrupted by an additional constant perturbation, e.g., the sensing matrix in compressive sensing Kulkarni et al. (2016), the blur kernel in non-blind image deblurring Tang et al. (2014), and a spatially-varying mask in image inpainting Mairal et al. (2007). In all these cases, deployment of the classic framework of LISTA necessitates a newly trained network for each new dictionary. An alternative to the above is incorporating LISTA as a fixed black-box denoiser, and merging it within the plug-and-play Venkatakrishnan et al. (2013) or RED Romano et al. (2017) schemes, at the cost of significantly increased inference complexity.
Our aim in this work is to extend the applicability of LISTA to scenarios of model perturbations and varying signal distributions. More specifically,
We bridge the gap between the efficiency and the fast convergence rate of LISTA, and the high adaptivity and applicability of ISTA (and FISTA), by introducing “Ada-LISTA” (Adaptive-LISTA). Our training is based on pairs of signals and their corresponding dictionaries, learning a generic architecture that wraps the dictionary by two auxiliary weight matrices. At inference, our model can accommodate the signal and its corresponding dictionary, allowing to handle a variety of model modifications without repetitive re-training.
We perform extensive numerical experiments, demonstrating the robustness of our model to three types of dictionary perturbations: permuted columns, additive Gaussian noise, and completely renewed random dictionaries. We demonstrate the ability of Ada-LISTA to handle complex and varying signal models while still providing an impressive advantage over both learned and non-learned solvers.
We prove that our modified scheme achieves a linear convergence rate under a constant dictionary. More importantly, we allow for noisy modifications and random permutations to the dictionary and prove that robustness remains, with an ability to reconstruct the ideal sparse representations with the same linear rate.
We demonstrate the use of our approach on natural image inpainting, a task that hard-coded models such as LISTA cannot address directly. We show a clear advantage of Ada-LISTA versus its non-learned counterparts.
Adopting a wider perspective, our study contributes to the understanding of learned solvers and their ability to accelerate convergence. Common belief suggests that the signal model should be structured and fixed for successful learning of such solvers. Our work reveals, however, that effective learning can be achieved with a weaker constraint: having a fixed conditional distribution of the data given the model, $p(y \mid D)$.
The LISTA concept of unfolding the iterations of a classical optimization scheme into an RNN-like neural network, and freeing its parameters to be learned over the training data, appears in many works. These include an unsupervised and online training procedure Sprechmann et al. (2015), a multi-layer version Sulam et al. (2019), a gated mechanism compensating shrinkage artifacts Wu et al. (2020), as well as reduced-parameter schemes Chen et al. (2018); Liu et al. (2019). This paradigm has been brought to various applications, such as compressed sensing, super-resolution, communication, MRI reconstruction Zhang and Ghanem (2018); Metzler et al. (2017); Wang et al. (2015); Borgerding et al. (2017); Sun et al. (2016); Hershey et al. (2014), and more. A prominent line of work investigates the success of such learned solvers from a theoretical point of view Xin et al. (2016); Wang et al. (2016); Moreau and Bruna (2016); Giryes et al. (2018); Zarka et al. (2019). Most of these consider a fixed signal model, with the exception of “robust-ALISTA” Liu et al. (2019) that introduces an adaptive variation of LISTA. This scheme, however, is restricted to small model perturbations, and cannot address more complicated model variations. A more detailed discussion of the relevant literature in relation to our study appears in Section 4.
2 Proposed Method
Thus far, as depicted in Figure 1, one could either benefit from a high convergence rate using a learned solver such as LISTA, while restricting the signals to a specific model, or employ a non-learned and less effective solver such as ISTA/FISTA that is capable of handling any pair of signal and its generative model. In this paper we introduce a novel architecture, termed “Adaptive-LISTA” (Ada-LISTA), combining both benefits. Beyond enjoying the acceleration benefits of learned solvers, we incorporate the dictionary as part of the input at both training and inference time, allowing for adaptivity to different models. Figure 2 provides our suggested architecture, based on the following:
Definition 1 (Ada-LISTA).
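The Ada-LISTA solver is defined by
$$x_{k+1} = \mathcal{S}_{\theta_{k+1}}\left( \left(I - \gamma_{k+1} D^T W_2^T W_2 D\right) x_k + \gamma_{k+1} D^T W_1^T y \right).$$
The signal $y$ and the dictionary $D$ are the inputs, and the learned parameters are the weight matrices $W_1, W_2$ and the per-layer step sizes and thresholds $\gamma_k, \theta_k$.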
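A forward pass of this solver can be sketched as follows. The update rule used here, $x_{k+1} = \mathcal{S}_{\theta_{k+1}}((I - \gamma_{k+1} D^T W_2^T W_2 D)x_k + \gamma_{k+1} D^T W_1^T y)$, is an assumption about the exact form of Definition 1, as are the function name `ada_lista_forward`, the zero initial point, and the per-layer lists `gammas` and `thetas`. Training is not shown.

```python
import numpy as np


def soft_threshold(v, theta):
    """Soft shrinkage applied elementwise."""
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)


def ada_lista_forward(y, D, W1, W2, gammas, thetas):
    """Assumed update: x <- S_theta((I - gamma D^T W2^T W2 D) x + gamma D^T W1^T y).

    y      : signal, shape (n,)
    D      : dictionary given as an input, shape (n, m)
    W1, W2 : learned weight matrices wrapping the dictionary, shape (n, n)
    gammas, thetas : per-unfolding step sizes and thresholds
    """
    WD = W2 @ D
    gram = WD.T @ WD              # D^T W2^T W2 D
    b = D.T @ (W1.T @ y)          # D^T W1^T y
    x = np.zeros(D.shape[1])      # assumed initialization x_0 = 0
    for gamma, theta in zip(gammas, thetas):
        x = soft_threshold(x - gamma * (gram @ x) + gamma * b, theta)
    return x
```

Under this assumed form, setting `W1 = W2 = I`, `gamma = 1/L` and `theta = lam/L` recovers plain ISTA, and feeding a column-permuted dictionary with the same weights simply permutes the output accordingly.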
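The inference (for both ISTA and FISTA) and the training procedures of Ada-LISTA are detailed in Algorithms 1 and 2 correspondingly. We consider a similar loss as in Equation 7, while also incorporating the concurrent dictionaries,
$$\min_{\Theta} \; \sum_{i} \left\| \mathcal{F}_K(y_i, D_i; \Theta) - x_i \right\|_2^2,$$
where $\mathcal{F}_K(y_i, D_i; \Theta)$ denotes the output of the $K$th unfolding for the signal $y_i$ and its dictionary $D_i$, and $x_i$ is the corresponding reference representation. This learning regime is supervised, requiring reference representations to be computed using ISTA. An unsupervised alternative can be envisioned, as in Sprechmann et al. (2015); Golts et al. (2018), where the loss is
$$\min_{\Theta} \; \sum_{i} \frac{1}{2}\left\| y_i - D_i \mathcal{F}_K(y_i, D_i; \Theta) \right\|_2^2 + \lambda \left\| \mathcal{F}_K(y_i, D_i; \Theta) \right\|_1.$$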
In this paper we shall focus on the supervised mode of learning, leaving the unsupervised alternative to future work.
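The difference between the two training modes can be illustrated by how each loss is evaluated on a batch of network outputs. This is a sketch under assumptions: the unsupervised loss is taken to be the Lasso objective evaluated at the predicted codes, the sums are unnormalized, and the function names are hypothetical.

```python
import numpy as np


def supervised_loss(x_pred, x_ref):
    """Sum over examples of ||x_pred_i - x_ref_i||_2^2 (one example per row)."""
    return float(np.sum((x_pred - x_ref) ** 2))


def unsupervised_loss(x_pred, ys, dicts, lam):
    """Sum over examples of 0.5 * ||y_i - D_i x_i||_2^2 + lam * ||x_i||_1.

    Needs no reference codes: only the signals, their dictionaries,
    and the predicted representations.
    """
    total = 0.0
    for x, y, D in zip(x_pred, ys, dicts):
        total += 0.5 * np.sum((y - D @ x) ** 2) + lam * np.sum(np.abs(x))
    return float(total)
```

The supervised loss needs a reference solver run for every training pair, whereas the unsupervised one trades that cost for an objective that depends on the choice of `lam`.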
Several key questions arise on the applicability of the above learned solver: Does it work? If so, is performance compromised by Ada-LISTA, as opposed to training LISTA for each separate model? To what extent can it be used? Can it handle completely random models? Can theoretical guarantees be provided on its convergence rate, or adaptation capability? We aim to answer these questions, and we start with a theorem on the robustness of our scheme, proving linear-rate convergence under a varying model.
3 Ada-LISTA: Theoretical Study
For the following study, we consider a reduced scheme of Ada-LISTA with a single weight matrix, so as to avoid complication in theorem conditions. We emphasize, however, that the same claims can be derived for our original scheme.
Definition 2 (Ada-LISTA – Single Weight Matrix).
Ada-LISTA with a single weight matrix is defined by
We start by recalling the definition of mutual coherence between two matrices:
Definition 3 (Mutual Coherence).
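Given two matrices, $A$ and $B$, if the diagonal elements of $A^T B$ are equal to $1$, then the mutual coherence is defined as
$$\mu(A, B) = \max_{i \ne j} \left| a_i^T b_j \right|,$$
where $a_i$ and $b_j$ are the $i$th and $j$th columns of $A$ and $B$.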
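As a small numerical companion to Definition 3, the function below computes the largest off-diagonal magnitude of $A^T B$ after checking the unit-diagonal condition. The name `mutual_coherence` and the tolerance argument are illustrative choices, not part of the paper.

```python
import numpy as np


def mutual_coherence(A, B, atol=1e-8):
    """Largest |a_i^T b_j| over i != j, for matrices with diag(A^T B) = 1."""
    G = A.T @ B
    if not np.allclose(np.diag(G), 1.0, atol=atol):
        raise ValueError("diagonal of A^T B must equal 1")
    off_diagonal = G - np.diag(np.diag(G))
    return float(np.abs(off_diagonal).max())
```

For a dictionary with unit-norm columns, `mutual_coherence(D, D)` is the usual self-coherence and lies between 0 and 1.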
Our first goal is to prove that Ada-LISTA is capable of solving the sparse coding problem in linear rate. We show that if all the signals emerge from the same dictionary , there exists a weight matrix and threshold values such that the recovery error decreases linearly over iterations. The following theorem indicates that if Ada-LISTA’s training reaches its global minimum, the rate would be at least linear. In this part, we follow the steps in Zarka et al. (2019), which generalize the proof of ALISTA Liu et al. (2019) to noisy signals. The proof for Theorem 1 appears in Appendix A.
Theorem 1 (Ada-LISTA Convergence Guarantee).
Consider a noisy input . If is sufficiently sparse,
and the thresholds satisfy the condition
with , , and , then the support in the th iteration of Ada-LISTA (Definition 2) is included in the support of , and its values satisfy
We proceed by claiming that Ada-LISTA can be adaptive to model variations. In this setting, we argue that the signal can originate from different models, and nonetheless there exist global parameters such that Ada-LISTA will converge in linear rate to the original representation. Our Theorem exposes the key idea that, as opposed to LISTA which corresponds to a single dictionary, Ada-LISTA can be flexible to various models, while still providing good generalization. Appendix B contains the proof of the following Theorem.
Theorem 2 (Ada-LISTA – The Applicable Dictionaries).
Consider a trained Ada-LISTA network with a fixed , and noisy input . If the following conditions hold:
1. The diagonal elements of are close to : ;
2. The off-diagonals are bounded: ;
3. is sufficiently sparse: ; and
4. The thresholds satisfy
with , , and ,
then the support of the th iteration of Ada-LISTA is included in the support of , and its values satisfy
An interesting question arising is the following: Once Ada-LISTA has been trained and the matrix is fixed, which dictionaries can be effectively served with the same parameters, without additional training? Theorem 2 reveals that as long as the effective matrix is sufficiently close to the identity matrix, linear convergence is guaranteed. This holds in particular for two interesting scenarios, proven in Appendices C and D:
Random permutations – If Ada-LISTA converges for signals emerging from , it also converges for signals originating from any permutation of ’s atoms.
Noisy dictionaries – If Ada-LISTA converges given a clean dictionary , satisfying , it also converges for noisy models , with some probability, depending on the distribution of .
To the best of our knowledge, Theorem 2 provides the first convergence guarantee in the presence of model variations, claiming that linear rate convergence is guaranteed, depending on the availability of small enough cardinality and low mutual-coherence . Note that the above claim, as in previous work Liu et al. (2019); Zarka et al. (2019), addresses the core capability of reaching linear convergence rate while disregarding both training and generalization errors.
4 Related Work
As already mentioned, the literature discussing LISTA and its successors is abundant. In this section we discuss the most relevant work to provide better context for our contribution.
The most relevant work to ours is “robust-ALISTA” Liu et al. (2019), introducing adaptivity to dictionary perturbations. Their work assumes that every signal comes from a different noisy model , where is an interference matrix. For each noisy dictionary this method computes an analytic matrix that minimizes the mutual coherence . Then, and are embedded in the architecture, and the training is performed over the step sizes and the thresholds only, leading to a considerable reduction in the number of trained parameters.
While robust-ALISTA considers model perturbations only, we show empirically that our method can handle more complicated model deviations, such as dictionary permutations and totally random models. Additionally, in terms of computational complexity, robust-ALISTA requires a complicated calculation of the analytic matrices at inference time, a limitation that does not exist in our scheme. We refer the reader to Appendix F for a more detailed discussion of the differences between the two methods.
As to the theoretical aspect of our study, Chen et al. (2018); Liu et al. (2019); Zarka et al. (2019); Wu et al. (2020) have recently shown that learned solvers can achieve linear convergence, under specific conditions on the sparsity level and mutual coherence. These results are the inspiration behind Theorem 1. This work, however, generalizes these guarantees to a varying model scenario, proving that the same weight matrix can serve different models while still reaching linear convergence.
5 Numerical Results
To demonstrate the effectiveness of our approach, we perform extensive numerical experiments, where our goal is two-fold. First, we examine Ada-LISTA on a variety of synthetic data scenarios, including column permutations of the input dictionary, additive noisy versions of it, and completely random input dictionaries. Second, we perform a natural image inpainting experiment, showcasing our robustness to a real-world task.
5.1 Synthetic Experiments
We construct a dictionary with random entries drawn from a normal distribution, and normalize its columns to have a unit -norm. Our signals are created as sparse combinations of atoms over this dictionary, . While the reported experiments in this section assume no additive noise, Appendix E presents a series of similar tests with varying levels of noise, showing the same qualitative results. The representation vectors are created by randomly choosing a support of cardinality with Gaussian coefficients, . Instead of using the true sparse representations as ground truth for training, we compute the Lasso solution with FISTA ( iterations, ), using the obtained signals and their corresponding dictionary . This is done in order to maintain a real-world setting, where one does not have access to the true sparse representations. We create in this manner examples for training, and for test. Our metric for comparison between different methods is the MSE (Mean Square Error) between the ground truth and the predicted sparse representations at unfoldings, . In all experiments, the Ada-LISTA weight matrices are both initialized as the identity matrix. In the following set of experiments we gradually diverge from the initial model, given by the dictionary , by applying different degradation/modifications to it.
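The data generation described above can be sketched as follows. All sizes in the snippet are small illustrative values chosen for this example and are not the settings used in the experiments; the function names are hypothetical, and the MSE is assumed to average over all entries.

```python
import numpy as np


def make_dictionary(n, m, rng):
    """Gaussian dictionary whose columns are normalized to unit L2 norm."""
    D = rng.standard_normal((n, m))
    return D / np.linalg.norm(D, axis=0, keepdims=True)


def make_sparse_codes(num, m, s, rng):
    """One code per row: random support of size s with Gaussian coefficients."""
    X = np.zeros((num, m))
    for i in range(num):
        support = rng.choice(m, size=s, replace=False)
        X[i, support] = rng.standard_normal(s)
    return X


def mse(x_ref, x_pred):
    """Mean squared error between reference and predicted representations."""
    return float(np.mean((x_ref - x_pred) ** 2))


rng = np.random.default_rng(0)
n, m, s, num = 20, 30, 3, 100      # illustrative sizes only
D = make_dictionary(n, m, rng)
X = make_sparse_codes(num, m, s, rng)
Y = X @ D.T                        # row i is the noiseless signal y_i = D x_i
```

In the setting of the text, the training targets would then be obtained by running FISTA on each pair of signal and dictionary rather than by using `X` directly.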
We start with a scenario in which the columns of the initial dictionary are permuted randomly to create a new dictionary. This transformation can occur in the non-convex process of dictionary learning, in which different initializations might incur a different order of the resulting atoms. Although the signals’ subspace remains intact, learned solvers such as LISTA, in which the dictionary is hard-coded during training, will most likely fail, as they cannot predict the updated support.
Here and below, we compare the results of four solvers: ISTA, FISTA, Oracle-LISTA and Ada-LISTA, all versus the number of iterations/unfoldings, . For each training example in ISTA, FISTA and Ada-LISTA, we create new instances of a permuted dictionary and its corresponding true representation, . We then apply FISTA for iterations and obtain the ground truth representations for the signal . Then ISTA and FISTA are applied for only iterations to solve for the pairs . Similarly, the supervised Ada-LISTA is given the ground truth for training. In Oracle-LISTA we solve a simpler problem in which the dictionary is fixed () for all training examples . The results in Figure 3 clearly show that Ada-LISTA is much more efficient compared to ISTA/FISTA, capable of mimicking the performance of the Oracle-LISTA, which considers a single constant .
In this experiment we aim to show that Ada-LISTA can handle a more challenging case in which the dictionary varies by . Each training signal is created by drawing a different noisy instance of the dictionary and a sparse representation , and solving the FISTA to obtain . ISTA and FISTA receive the pairs , and Ada-LISTA receives the triplet . By vanilla LISTA, we refer to a learned solver that obtains , and trains a network while disregarding the changing models. Oracle-LISTA, as before, handles a simpler case in which the dictionary is fixed, being , and all signals are created from it.
Figure 4 presents the performance of the different solvers with a decreasing SNR (Signal to Noise Ratio) of the dictionary
In this setting, we diverge even further from a fixed model, and examine the capability of our method to handle completely random input dictionaries. This time, for each training example we create a different Gaussian normalized dictionary , and a corresponding representation vector with an increasing cardinality: . The resulting signals, , and their corresponding dictionaries are fed to FISTA to obtain the ground truth sparse representations for training, . We compare the performance of ISTA, FISTA, Ada-LISTA and Oracle-LISTA. Similarly to previous experiments, Ada-LISTA is fed during training with the triplet . Vanilla LISTA cannot handle such variation in the input distribution, and thus it is omitted. For reference, we show the results of Oracle-LISTA in which all of the training signals are created from the same dictionary.
As can be seen in Figure 5, for a small cardinality of , Oracle-LISTA is able to drastically lower the reconstruction error as compared to ISTA and FISTA. This result, however, has already been demonstrated in Gregor and LeCun (2010). Ada-LISTA which deals with a much more complex scenario, still provides a similar improvement over both ISTA and FISTA. As the cardinality increases to , the performance of both learned solvers deteriorates, and the improvement over their non-learned counterparts diminishes.
The last experiment provides a valuable insight on the success of LISTA-like learned solvers. The common belief is that acceleration in convergence can be obtained when the signals are restricted to a union of low-dimensional subspaces, as opposed to the entire signal space. The above experiment suggests otherwise: Although the signals occupy the whole space, Ada-LISTA still achieves improved convergence. This implies that the underlying structure should be only of the signal given its generative model , as opposed to the signal model, . In the above, even if the dictionaries are random, the signals must be sparse combinations of atoms. As this assumption of structure weakens with the increased cardinality, the resulting acceleration becomes less prominent. We believe that this conditional information is the key for improved convergence.
5.2 Natural Image Inpainting
In this section we apply our method to a natural image inpainting task. We assume the image is corrupted by a known mask with a ratio of missing pixels. Thus, the updated objective we wish to solve is
$$\min_{x} \; \frac{1}{2}\|y - M D x\|_2^2 + \lambda \|x\|_1,$$
where $y$ is a corrupt patch of the same size as the clean one, $D$ is a dictionary trained on clean image patches, and $M$ represents the mask, being an identity matrix with a percentage of diagonal elements equal to zero. Thus, the dictionary $D$ is constant, but each patch has a different (yet known) inpainting mask, and thus the effective dictionary $MD$ changes for each signal.
We slightly change the formulation of the model described in Section 2, and reverse the roles of the input and learned matrices. Specifically, the updated shrinkage step (Equation 3) for image inpainting is
$$x_{k+1} = \mathcal{S}_{\lambda/L}\left(x_k + \frac{1}{L} D^T M^T (y - M D x_k)\right).$$
We consider the mask $M$ as part of the input, while the dictionary $D$ is learned with the following parameterization:
$$x_{k+1} = \mathcal{S}_{\theta_{k+1}}\left( \left(I - \gamma_{k+1} W_2^T M^T M W_2\right) x_k + \gamma_{k+1} W_1^T M^T y \right), \tag{18}$$
where $W_1, W_2$ are the same size as the dictionary $D$, and initialized by it.
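For intuition, the non-learned masked shrinkage step can be written directly in NumPy by treating the masked dictionary as the effective dictionary. This sketch covers only the ISTA baseline on $MD$, not the learned network; the helper names and the representation of the mask as a 0/1 vector holding the diagonal of $M$ are assumptions made for the example.

```python
import numpy as np


def soft_threshold(v, theta):
    """Soft shrinkage applied elementwise."""
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)


def make_mask(n, missing_ratio, rng):
    """0/1 vector holding the diagonal of M; zeros mark missing pixels."""
    mask = np.ones(n)
    num_missing = int(round(missing_ratio * n))
    mask[rng.choice(n, size=num_missing, replace=False)] = 0.0
    return mask


def masked_ista(y, D, mask, lam, num_iters):
    """ISTA on the effective dictionary M D, with M = diag(mask)."""
    MD = mask[:, None] * D                    # zero out rows of missing pixels
    L = np.linalg.eigvalsh(MD.T @ MD).max()
    x = np.zeros(D.shape[1])
    for _ in range(num_iters):
        x = soft_threshold(x + MD.T @ (y - MD @ x) / L, lam / L)
    return x
```

A restored patch is then `D @ x`, using the full, unmasked dictionary. Because the rows of `MD` at missing pixels are zero, whatever values `y` holds there never influence the result.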
In order to collect natural image patches, we use the BSDS500 dataset Martin et al. (2001) and divide it to and training, validation and test images correspondingly. To train the dictionary , we extract patches at random locations from the train images, subtract their mean and divide by the average standard deviation. The dictionary of size is learned via scikit-learn’s function MiniBatchDictionaryLearning with . To train our network, we randomly pick a subset of training and validation patches. We train the network to perform an image inpainting task with ratio of . Instead of using Ada-LISTA as before, we tweak the architecture described in Equation (18) to unfold the FISTA algorithm, termed Ada-LFISTA, as described in algorithm 1. The input to our network is triplets of the corrupt train patches , their corresponding mask , and the solutions of the FISTA solver applied for iterations on the corrupt signals. The output is the reconstructed representations .
We evaluate the performance of our method on images from the popular Set11, corrupted with the same inpainting ratio of , and compare between ISTA, FISTA and Ada-LFISTA for a fixed number of iterations/unfoldings. We extract all overlapping patches in each image, subtract the mean and divide by the standard deviation, apply each solver, un-normalize the patches and return their mean, and then place them in their correct position in the image and average over overlaps. The quality of the results is measured in PSNR between the clean images and the reconstruction of their corrupt version. The patch-wise validation error versus the number of unfoldings is given in Figure 7; numerical results are given in Table 1, and select qualitative results are shown in Figure 6 and more in Appendix G. There is a clear advantage to Ada-LFISTA over the non-learned ISTA and FISTA solvers. In this setting of missing pixels, a hard-coded solver with a fixed , such as LISTA, cannot deal with the changing mask of each patch.
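The patch-based evaluation pipeline relies on two generic operations, overlap averaging and PSNR, which can be sketched as below. The per-patch normalization and the solver call are omitted; the function names, the row-major patch ordering, and the peak value of 255 are assumptions made for the illustration.

```python
import numpy as np


def extract_patches(img, p):
    """All overlapping p x p patches, flattened, in row-major order."""
    H, W = img.shape
    return np.stack([img[i:i + p, j:j + p].ravel()
                     for i in range(H - p + 1)
                     for j in range(W - p + 1)])


def average_patches(patches, shape, p):
    """Place patches back at their positions and average over overlaps."""
    H, W = shape
    acc = np.zeros(shape)
    count = np.zeros(shape)
    k = 0
    for i in range(H - p + 1):
        for j in range(W - p + 1):
            acc[i:i + p, j:j + p] += patches[k].reshape(p, p)
            count[i:i + p, j:j + p] += 1.0
            k += 1
    return acc / count


def psnr(clean, estimate, peak=255.0):
    """Peak signal-to-noise ratio in dB (assumes the two images differ)."""
    mse = np.mean((clean - estimate) ** 2)
    return float(10.0 * np.log10(peak ** 2 / mse))
```

Applying `average_patches` to unmodified patches returns the original image, which is a convenient sanity check before plugging a solver into the loop.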
We have introduced a new extension of LISTA, termed Ada-LISTA, which receives both the signals and their dictionaries as input, and learns a universal architecture that can cope with the varying models. This modification produces great flexibility in working with changing dictionaries, leveling the playing field with non-learned solvers such as ISTA and FISTA that are agnostic to the entire signal distribution, while enjoying the acceleration and convergence benefits of learned solvers. We have substantiated the validity of our method, both in a comprehensive theoretical study, and with extensive synthetic and real-world experiments. Future work includes further investigation of the discussed rationale, and an extension to additional applications.
Appendix A Proof of Theorem 1
This proof follows the steps from Zarka et al. (2019), with slight modifications to fit our scheme. Following the notations in Theorem 1, denotes the true sparse representation of the signal , and . In addition, we define as the support of a vector.
For any iteration the following hold
The estimated support is contained in the true support,
The recovery error is bounded by
We start by showing that the induction hypothesis holds for . Since we get that the support is empty and the support hypothesis Equation 19 holds. As for the recovery error, we get that
Therefore, to verify Equation 20 we need to show that
Since , for any index we can write
where denotes the th column in and denotes the th element in . Multiplying each side by we get
Since by assumption , the left term becomes . In addition, since, by assumption, there are no more than nonzeros in and is bounded by , we get the following bound
By taking a maximum over we obtain
Since we have assumed that , and
Finally, since , and , we get
as in Equation 20, and therefore the recovery error hypothesis holds for the base case.
Assuming the induction hypothesis holds for iteration , we show that it also holds for the next iteration . We define and denote by the subset of columns in .
Placing , we get
Since , the following holds:
Therefore, Equation 31 becomes
We aim to show that for any , , as the support hypothesis suggests. Since , we can bound the input argument of the soft threshold by
Using the induction assumption on the support, , we can upper bound the first term in the right-hand-side,
Using the induction assumption on the recovery error (Equation 20), we have . Therefore, we get
However, by our assumptions,
and by placing we get
Since is the input to the soft threshold operator , and it is no bigger than the threshold, we get that , and the support hypothesis holds.
We proceed by proving that the recovery error hypothesis also holds (Equation 20). We use the fact that for any scalar triplet, , the soft threshold satisfies
Therefore, following Equation 32 we get
As before, since , we have
Therefore, by using Equation 36 we get
and by placing we obtain
By taking a maximum over , we establish the recovery error hypothesis (Equation 20), concluding the proof. ∎
Appendix B Proof of Theorem 2
We define an effective matrix . In this part, we aim to prove that linear convergence is guaranteed for any dictionary , satisfying two conditions: (i) the diagonal elements of are close to , and (ii) the off-diagonal elements of are bounded.
This proof is based on Appendix A, with the following two modifications: The mutual coherence is replaced with , and the diagonal element is not assumed to be equal to , but rather bounded from below by .
The base case of the induction (Equation 26) now becomes:
Since we assume , and
As , , therefore the induction hypothesis holds for the base case.
Moving to the inductive step, the proof of the support hypothesis remains almost the same, apart from replacing with . This is due to the fact that if , then , and therefore the diagonal elements multiply zero elements.
As to the recovery error hypothesis, we need to upper bound for . Since we need to modify Equation 31:
Using Equation 39 we get that is upper bounded by
which in turn is upper bounded by
Placing results in
Taking a maximum over establishes the recovery error assumption, proving the induction hypothesis. ∎
Appendix C Proof for Random Permutations
We show that if the weight matrix leads to linear convergence for signals generated by , then linear convergence is also guaranteed for signals originating from , where is a permutation matrix. The proof is straightforward, as the permutation matrix does not flip diagonal and off-diagonal elements in the effective matrix . Thus, the mutual coherence does not change and the conditions of Theorem 2 hold, establishing linear convergence.
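The argument can be checked numerically. The sketch below assumes the effective matrix has the form $D^T W^T D$ (its symbol is not visible in the text here, so this form is an assumption), uses a random stand-in for the trained weight matrix, and verifies that permuting the atoms only reorders the diagonal and off-diagonal entries without mixing them.

```python
import numpy as np


def effective_matrix(D, W):
    """Assumed effective matrix D^T W^T D, of size m x m."""
    return D.T @ W.T @ D


def max_offdiag(G):
    """Largest magnitude among the off-diagonal entries of G."""
    return float(np.abs(G - np.diag(np.diag(G))).max())


rng = np.random.default_rng(0)
n, m = 8, 12                        # illustrative sizes
D = rng.standard_normal((n, m))
W = rng.standard_normal((n, n))     # stand-in for a trained weight matrix
perm = rng.permutation(m)
P = np.eye(m)[:, perm]              # permutation matrix, so D @ P == D[:, perm]

G = effective_matrix(D, W)
G_perm = effective_matrix(D @ P, W)  # equals P^T G P
```

Since `G_perm` is a symmetric reindexing of `G`, its diagonal is `np.diag(G)[perm]` and its largest off-diagonal magnitude matches that of `G`, so any bounds on these quantities carry over to the permuted dictionary.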
Appendix D Proof for Noisy Dictionaries
We now consider signals from noisy models, , where , and the model deviations are of Gaussian distribution, . Given pairs of ), we show that Ada-LISTA recovers the original representations , with respect to their model in linear rate.
Theorem 3 (Ada-LISTA Convergence – Noisy Model).
Consider a noisy input , where , . If for some constants , is sufficiently sparse,
and the thresholds satisfy
with , , , , and , then, with probability of at least , the support of the th iteration of Ada-LISTA is included in the support of and its values satisfy
The proof for this theorem consists of two stages. First, we study the effect of model perturbations on the effective matrix , deriving probabilistic bounds for the changes in the diagonal and off-diagonal elements. Then, we place these bounds in Theorem 2 to guarantee linear rate.
We start by bounding the changes in the effective matrix . These deviations modify the off-diagonal elements, which are no longer bounded by , and the diagonal elements that are not equal to anymore. Define as:
This implies is equal to:
Since and the elements in are independent, the expected value of is
To bound the changes in we aim to use Cantelli’s inequality, but first, we need to find the variance of :
In what follows we calculate each term in the right-hand-side, starting with :
Moving on to , we get