A neural network with feature sparsity
Abstract
We propose a neural network model, with a separate linear (residual) term, that explicitly bounds the input-layer weights for a feature by the linear weight for that feature. The model can be seen as a modification of so-called residual neural networks that produces a path of models that are feature-sparse, that is, models that use only a subset of the features. This is analogous to the solution path of the usual Lasso ($\ell_1$-regularized) linear regression. We call the proposed procedure LassoNet and develop a projected proximal gradient algorithm for its optimization. The approach can sometimes give as low or lower test error than a standard neural network, and its feature selection yields more interpretable solutions. We illustrate the method on both simulated and real data examples, and show that it often achieves competitive performance with a much smaller number of input features.
1 Introduction
In many applications, neural networks achieve state-of-the-art accuracy. This technology has provided near-human performance on many prediction tasks, and has left deep marks on entire fields of business and science, to the extent that large computational and engineering efforts are routinely dedicated to neural network training and optimization. However, neural networks are often criticized for their complexity and lack of interpretability. There are many arguments that favor simple models over more complex ones. In many applications (including medical diagnosis, insurance and banking products, and flight control, among others) interpretation of the underlying model is a critical requirement. On the other hand, traditional statistical tools, including simple linear models, remain popular because they are simple and explainable, with cheap, efficient computational tools readily available. As a consequence, there has been growing interest in simplifying and understanding neural networks while preserving their performance. This paper attempts to bridge the gap between interpretability and performance by enforcing sparsity on the input features fed into the neural network.
The Lasso estimate (Tibshirani (1996)) is perhaps the most prominent tool for enforcing sparsity in linear models, thereby easing their interpretability. The sparsity pattern provides a direct way to identify the most important features in the model. Just like the Lasso, our proposed model is controlled by a penalty factor that is chosen through cross-validation or from a validation set. In addition to improving explainability, a feature-sparse structure may improve the model's generalization ability, as our experiments confirm.
Figure 1 shows an example, using the Boston housing data, which consists of 506 observations of house prices and 13 predictors. We divided the data into training and test sets of sizes 404 and 102 respectively, and then added 13 Gaussian noise predictors. The figure shows the test error for the Lasso, a single hidden layer neural network, and our proposed procedure LassoNet. (We recently discovered a procedure proposed by Webber et al. (2018), with an R language implementation of the same name. However, it treats a completely different problem, a linear model with network constraints. With apologies to these authors, the name "LassoNet" seems very appropriate for the method proposed here.) We see that the new procedure achieves its minimum test error at around 13 predictors, and does as well as the standard neural network. Furthermore, 11 of the first 13 predictors chosen came from the list of real rather than artificial predictors.
The outline of this paper is as follows. In Section 2 we briefly review related work on this problem. Section 3 introduces the proposed LassoNet procedure. In Sections 4 and 5 we examine the performance of the procedure on simulated and real datasets. The path-wise optimization procedure is detailed in Section 6, and we end with some extensions and discussion in Section 7.
2 Short review of past work
Sparse neural architectures are a topic of active research. While much of the literature has focused on full network sparsity (for example, for computational, memory, or time efficiency), this brief review focuses on feature sparsity. For many of today's high-dimensional datasets, e.g. from biomedical or computer vision problems (among many others), it is of particular interest to find a small subset of input features that contributes most of the explanatory power (Guyon & Elisseeff, 2003).
Pruning procedures. The simplest and most popular approach relies on so-called pruning methods, which share the common trait of being performed after model training. Verikas & Bacauskiene (2002) proposed to examine the sensitivity of the output to the removal of individual input features. However, sensitivity-based approaches suffer from treating features individually: it would be combinatorially intractable to consider larger groups of features, yet feature importance unavoidably depends on which other features are present in the model (cf. Section 3.3 of Guyon & Elisseeff (2003)). Han et al. (2015) propose a thresholding approach: first, the network is trained in its entirety with a regularization penalty; then, weights whose magnitudes fall below a certain threshold are set to zero, and the model is retrained from scratch on the remaining features. In addition to lengthening the training time, such a method requires a separate pruning process. Unifying training and selection would be a desirable property of the model, especially from the selective inference perspective. Reed (1993) provides a more comprehensive survey of pruning-related methods.
Weight regularization. A somewhat more principled approach is to enforce sparsity through penalization. The $\ell_1$ norm acts as a convex proxy for the $\ell_0$ norm, which directly penalizes the number of active features, and originates in the Lasso estimator. There has been growing interest in methods that produce structured sparsity (Wen et al. (2016); Yoon & Hwang (2017); Scardapane et al. (2017)). These methods use the group lasso penalty, a variant of the lasso that forces entire groups of neural weights to zero together. For example, Scardapane et al. (2017) enforce sparsity by applying a group lasso penalty to the set of weights that originate from a given feature. However, because this penalty is non-differentiable at the origin, standard gradient-based methods do not produce exact zeros, and a thresholding step after optimization is generally required to obtain precise sparsity.
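The per-feature grouping just described amounts to a simple penalty. Below is a minimal sketch in Python; the function name and the list-of-lists layout (row $j$ holding the first-layer weights that leave feature $j$) are our own illustration, not an API from any of the cited works.

```python
import math

def group_lasso_penalty(theta):
    """Group-lasso penalty over per-feature weight groups: sum_j ||theta_j||_2.

    theta is a list of rows; row j collects the first-layer weights that
    originate from input feature j (the grouping used by Scardapane et
    al. (2017)). Zeroing a whole row removes feature j from the network.
    """
    return sum(math.sqrt(sum(w * w for w in row)) for row in theta)
```

Because the penalty is the sum of Euclidean norms of whole rows, its minimizers tend to set entire rows to zero at once, which is exactly what removes a feature.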
3 The LassoNet proposal
3.1 Formulation
We start with the usual regression data $(x_i, y_i)$, $i = 1, 2, \ldots, n$, with $x_i \in \mathbb{R}^p$ and $y_i$ a quantitative outcome. We assume the model

(1)  $y_i = \beta_0 + \sum_{j=1}^p x_{ij}\beta_j + \sum_{k=1}^K w_k\, g(\theta_k^\top x_i) + \epsilon_i$

with $\epsilon_i \sim N(0, \sigma^2)$. Here $g$ is a monotone, nonlinear function such as a sigmoid or rectified linear unit, and each $\theta_k \in \mathbb{R}^p$ is a vector of first-layer weights feeding the $k$th hidden unit. Our objective is to minimize
(2)  $\displaystyle \operatorname*{minimize}_{\beta_0,\beta,\theta,w}\;\; \frac{1}{2}\sum_{i=1}^n \Big(y_i - \beta_0 - \sum_{j=1}^p x_{ij}\beta_j - \sum_{k=1}^K w_k\, g(\theta_k^\top x_i)\Big)^2 + \lambda \sum_{j=1}^p |\beta_j| + \lambda_2 \sum_{j=1}^p \sum_{k=1}^K |\theta_{jk}| \quad \text{subject to } |\theta_{jk}| \le M\,|\beta_j| \text{ for all } j, k.$
Extensions to other problems such as classification are straightforward, by replacing the squared error term with a binomial or multinomial log-likelihood.
Note that there are three tuning parameters: the main parameter $\lambda$ controls the complexity and sparsity of the fitted model. The secondary parameter $M$ controls the relative complexity of the linear and nonlinear parts of the model, with $M = 1$ being the default. Typically we vary $\lambda$ over a dense grid that spans the spectrum from dense to sparse, and try a few values for $M$. We use a validation set or cross-validation to choose optimal values for $\lambda$ and $M$. Finally, the term $\lambda_2$, while not needed to control complexity, can provide additional sparsity in the $\theta_{jk}$. While our algorithm can handle values $\lambda_2 > 0$, we set $\lambda_2 = 0$ in all examples in the paper.
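As a concrete, unofficial illustration of objective (2) and its constraint, the sketch below evaluates the penalized loss for a single hidden layer with ReLU activation. All names and the list-based data layout are our own; this is not code from any released LassoNet package.

```python
def relu(z):
    return max(0.0, z)

def lassonet_objective(X, y, beta0, beta, theta, w, lam, M, lam2=0.0):
    """Evaluate objective (2) at the given parameters, and check the
    feasibility of the constraint |theta_jk| <= M * |beta_j|.

    X: n x p list of lists; beta: length p; theta: p x K with theta[j][k]
    the weight from feature j to hidden unit k; w: length K output weights.
    Returns (objective value, feasible flag). Illustrative sketch only.
    """
    n, p, K = len(X), len(beta), len(w)
    feasible = all(abs(theta[j][k]) <= M * abs(beta[j]) + 1e-12
                   for j in range(p) for k in range(K))
    sse = 0.0
    for i in range(n):
        linear = beta0 + sum(X[i][j] * beta[j] for j in range(p))
        hidden = sum(w[k] * relu(sum(theta[j][k] * X[i][j] for j in range(p)))
                     for k in range(K))
        sse += (y[i] - linear - hidden) ** 2
    penalty = lam * sum(abs(b) for b in beta) + lam2 * sum(
        abs(theta[j][k]) for j in range(p) for k in range(K))
    return 0.5 * sse + penalty, feasible
```

Note how the constraint couples the two parts of the model: once $\beta_j = 0$, every first-layer weight leaving feature $j$ is forced to zero as well, which is what makes the model feature-sparse.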
3.2 Strategy for path optimization
Our strategy is to optimize (2) for each fixed value of $M$, over a wide path of $\lambda$ values. This kind of strategy is used in Lasso ($\ell_1$-regularized) linear regression, where coordinate descent is often applied for each fixed value of $\lambda$ (see Friedman et al. (2010)). In that approach, optimization is carried out on a fine grid of $\lambda$ values from sparse (large $\lambda$) to dense (small $\lambda$), using warm starts in the parameters for each new value of $\lambda$. This is very effective, since the sparse models are easier to optimize and the sparse regime is often the main (or only) region of interest. Further, the convexity of the Lasso problem ensures that we can find the global minimum for each $\lambda$.
Somewhat to our surprise, we have found that in the (nonconvex) LassoNet problem, this sparse-to-dense approach does not work well, and can get caught in poor local minima. Instead, we find that a dense-to-sparse strategy is far more effective.
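The warm-start mechanics behind both strategies can be sketched on a toy problem. Below is a plain proximal-gradient Lasso solver (purely illustrative, not the LassoNet optimizer) whose path loop reuses the previous solution as the starting point for the next $\lambda$; reversing the order of `lambdas` switches between the sparse-to-dense and dense-to-sparse strategies.

```python
def soft_threshold(z, t):
    """Proximal operator of t * |.| (the soft-thresholding function)."""
    if z > t:
        return z - t
    if z < -t:
        return z + t
    return 0.0

def lasso_prox_grad(X, y, lam, beta, step=0.1, iters=500):
    """One Lasso fit by proximal gradient, starting from `beta` (warm start)."""
    n, p = len(X), len(X[0])
    beta = list(beta)
    for _ in range(iters):
        # gradient of 0.5 * sum_i (x_i' beta - y_i)^2
        resid = [sum(X[i][j] * beta[j] for j in range(p)) - y[i]
                 for i in range(n)]
        grad = [sum(X[i][j] * resid[i] for i in range(n)) for j in range(p)]
        beta = [soft_threshold(beta[j] - step * grad[j], step * lam)
                for j in range(p)]
    return beta

def path_dense_to_sparse(X, y, lambdas):
    """Warm-started path: lambdas visited in increasing order (dense -> sparse)."""
    beta = [0.0] * len(X[0])
    path = []
    for lam in sorted(lambdas):
        beta = lasso_prox_grad(X, y, lam, beta)  # warm start from previous fit
        path.append((lam, list(beta)))
    return path
```

For the convex Lasso the visiting order only affects speed, not the solutions; the point of the paragraph above is that for the nonconvex LassoNet objective the order also changes which local minima are reached.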
Figure 2 shows an example. We generated standard independent Gaussian features, and the outcome from the model
(3) 
giving a signal-to-noise ratio of about 3. We generated a test set of 1000 observations for assessing model performance. Figure 2 shows the test error for the Lasso and for LassoNet using both strategies. We see that only the dense-to-sparse approach captures the nonlinear signal.
Figure 3 shows the path of solutions as $\lambda$ is varied.
4 Simulation study
We carried out a simulation study to compare the performance of LassoNet to the Lasso and a standard single hidden layer neural network. For further exploration, we also implemented a two-stage sparse neural network procedure with a linear term, defined as follows:

1. Compute the Lasso path of solutions for the linear model.

2. For each of a number of equally spaced solutions along the path:

(a) Compute the residual $r$ from the Lasso solution.

(b) Fit a neural network $\hat g$ to $r$, using just the features with nonzero weights in the Lasso solution.

(c) Construct the final solution as the sum of the Lasso fit and the fitted network $\hat g$.
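The steps above can be sketched as follows, with the neural network stage left as a pluggable regressor (the paper uses a single hidden layer net; the names and the `fit_network` interface here are our own illustration).

```python
def two_stage_fit(X, y, lasso_betas, fit_network):
    """Two-stage procedure: Lasso fit, then a regressor on the residuals.

    lasso_betas: list of coefficient vectors along the Lasso path.
    fit_network: callable (X_sub, residuals) -> predict function; stands in
    for the neural network stage. Returns one predictor per path point.
    """
    models = []
    for beta in lasso_betas:
        active = [j for j, b in enumerate(beta) if b != 0.0]
        # (a) residual from the Lasso solution
        r = [yi - sum(xi[j] * beta[j] for j in range(len(beta)))
             for xi, yi in zip(X, y)]
        # (b) fit the second-stage model to r on the active features only
        X_sub = [[xi[j] for j in active] for xi in X]
        g = fit_network(X_sub, r)
        # (c) final prediction = Lasso fit + second-stage fit
        def predict(x, beta=beta, active=active, g=g):
            lin = sum(x[j] * beta[j] for j in range(len(beta)))
            return lin + g([x[j] for j in active])
        models.append(predict)
    return models
```

Any regressor can be plugged in for `fit_network`; in the study it is a single hidden layer neural network restricted to the Lasso-selected features.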

All neural nets in the study had the same number of hidden units and used the sigmoid transfer function; the bound parameter $M$ for LassoNet was set to 1. We created a training set under 6 different scenarios of the form $y = f(x) + \epsilon$. We also generated a validation set and a test set of size 1000 from the same population. The validation set was used to choose the tuning parameter ($\lambda$ for the Lasso, two-stage and LassoNet procedures; the number of training epochs for the NN), and the chosen model was then evaluated on the test set.

1. Linear, with the signs of the coefficients chosen at random.

2. Linear + Nonlinear, with strong linear signal.

3. Linear + Nonlinear, with weak linear signal.

4. Linear + Nonlinear, with weak linear signal and non-hierarchical structure.

5. Sum of two sigmoid functions in all features.

6. Friedman's function, taken from Friedman (1991), where it was used to assess the MARS procedure.
All features were generated as independent standard Gaussians, except in scenario 6, where each entry was drawn from a uniform distribution. In each of settings 2-4, the nonlinear signal is uncorrelated with the linear component. Figure 4 shows the test-set mean squared errors over 10 realizations of each method, in the six scenarios.
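For reference, scenario 6 can be simulated as below. This assumes the standard form of Friedman's (1991) benchmark, $y = 10\sin(\pi x_1 x_2) + 20(x_3 - 0.5)^2 + 10x_4 + 5x_5 + \epsilon$ with features uniform on $[0, 1]$; the exact constants and feature range used in this paper's experiment are not stated in the text, so treat these as the conventional defaults.

```python
import math
import random

def friedman1(n, p=10, noise_sd=1.0, seed=0):
    """Simulate n observations from Friedman's (1991) benchmark function.

    Only the first five of the p features carry signal; the rest are noise
    features, which makes this a natural test bed for feature selection.
    """
    rng = random.Random(seed)
    X, y = [], []
    for _ in range(n):
        x = [rng.random() for _ in range(p)]  # uniform on [0, 1]
        signal = (10.0 * math.sin(math.pi * x[0] * x[1])
                  + 20.0 * (x[2] - 0.5) ** 2
                  + 10.0 * x[3] + 5.0 * x[4])
        X.append(x)
        y.append(signal + rng.gauss(0.0, noise_sd))
    return X, y
```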
In the top left panel (the linear setting), the Lasso, the two-stage procedure and LassoNet all perform well, with the neural network lagging behind. LassoNet shows an advantage in scenarios 2-4, where the underlying function has both linear and nonlinear components. The non-hierarchical setup has only a mild effect on LassoNet.
Scenario 5 is a setup tailored to the neural network and, unsurprisingly, it performs best there. In scenario 6 we see that LassoNet and the two-stage procedure outperform the Lasso, and provide a considerable improvement over a standard neural network.
5 Real data examples
In looking for real data examples, we naturally sought problems where a standard neural network outperforms a linear model fit via the Lasso. Since we are not working with deep (multiple hidden layer) nets in this paper, we did not consider problems where multi-level feature extraction and convolution can be effective. Rather, we considered "flat" learning problems, where a single hidden layer net should be effective. We also focused on observational data where the signal-to-noise ratio is relatively low, and feature interpretation is likely to be important. To our surprise, we found that for most of the problems we looked at (e.g. from the UCI database), the Lasso performed as well as the neural network. We wonder what experience others have had in this regard.
5.1 NRTI HIV data
Rhee et al. (2003) study six nucleoside reverse transcriptase inhibitors (NRTIs) that are used to treat HIV-1. The target of these drugs can become resistant through mutation, and the authors compare a collection of models for predicting each drug's (log) susceptibility, a measure of drug resistance, based on the location of mutations. We used the data for the first inhibitor, and divided the data into roughly equal-sized training and test sets. The test-set mean squared errors are shown in Figure 5. LassoNet improves upon both the Lasso and a neural network, and does best using only about 60 features. Figure 6 shows the number of features in common between the Lasso and LassoNet fits as we move along the solution paths. We see nearly 100% agreement, which is reassuring and aids in the interpretation of the LassoNet model.
5.2 MNIST data
We selected a training set of 1000 1s and 2s from the well-known MNIST dataset, each image being a matrix of gray-scale values from 0 to 255. In this problem a multi-layer deep net is likely to work best, but for illustration we consider just a single hidden layer net here. We applied the Lasso, a single hidden layer NN, and LassoNet with the same architecture. Figure 7 shows the results from the Lasso, NN and LassoNet. For the standard NN we have plotted an image whose entries measure the importance of each feature. The LassoNet solution shown is the one with about 60 nonzero features, and it uncovers the pixels most important for discriminating 1s from 2s.
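One plausible reading of the importance measure plotted for the standard NN (the precise formula is elided in the text above) is the total absolute first-layer weight leaving each input pixel; a minimal sketch under that assumption:

```python
def feature_importance(theta):
    """Per-feature importance for a single hidden layer net, taken here as
    the sum of absolute first-layer weights leaving each input.

    theta: p x K list of lists, theta[j][k] being the weight from input j
    to hidden unit k. This is one common choice, not necessarily the exact
    measure used in the figure.
    """
    return [sum(abs(w) for w in row) for row in theta]
```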
5.3 Ames housing dataset
This example is taken from Cock (2011) and was used in the Kaggle competition "House Prices: Advanced Regression Techniques". The goal is to predict housing prices in Ames, Iowa. There are 1460 observations, and 863 predictors after one-hot encoding the categorical predictors and removing predictors with missing values. We divided the data into training and test sets. Figure 8 shows the test-set prediction errors. We see that LassoNet offers a substantial improvement in test error, using only a fraction of the predictors.
6 Optimization formulation
Our optimization problem consists in minimizing the objective function defined in (2):

(4)  $\displaystyle \operatorname*{minimize}_{\beta_0,\beta,\theta,w}\;\; \frac{1}{2}\sum_{i=1}^n (y_i - \hat y_i)^2 + \lambda \sum_{j=1}^p |\beta_j| \quad \text{subject to } |\theta_{jk}| \le M\,|\beta_j| \text{ for all } j, k,$

where the prediction is defined by

$\hat y_i = \beta_0 + \sum_{j=1}^p x_{ij}\beta_j + \sum_{k=1}^K w_k\, g(\theta_k^\top x_i)$

for each $i = 1, \ldots, n$. For simplicity we have omitted the $\lambda_2$ penalty from (2). Unfortunately, this problem is nonconvex, due both to the presence of the nonlinear activation function $g$ and to the inequality constraint. However, it can be tackled using the proximal gradient descent algorithm (Parikh et al. (2014); Beck (2017)). As a notational shorthand, we write $u = (\beta_0, \beta, \theta, w)$ for the full parameter vector and let $\ell(u) = \frac{1}{2}\sum_{i=1}^n (y_i - \hat y_i)^2$ denote the smooth part of the objective.
The proximal gradient descent algorithm is an iterative procedure; we use $u^t$ to denote the parameters, and $\alpha_t$ to denote the step size, at the $t$th iteration. The algorithm updates via the following:

(5)  $\displaystyle u^{t+1} = \operatorname*{argmin}_{u \,:\, |\theta_{jk}| \le M|\beta_j| \;\forall j,k}\;\; \frac{1}{2\alpha_t}\big\| u - \big(u^t - \alpha_t \nabla \ell(u^t)\big) \big\|_2^2 + \lambda \sum_{j=1}^p |\beta_j|.$
It is easy to find the update rule for $(\beta_0, w)$: these parameters are unpenalized and unconstrained, so Eq (5) shows that

$\beta_0^{t+1} = \beta_0^t - \alpha_t \nabla_{\beta_0}\ell(u^t), \qquad w^{t+1} = w^t - \alpha_t \nabla_{w}\ell(u^t),$

i.e., a plain gradient step.
Moreover, Eq (5) implies that $(\beta^{t+1}, \theta^{t+1})$ satisfies the following: for each $j = 1, \ldots, p$,

(6)  $\displaystyle (\beta_j^{t+1}, \theta_{j\cdot}^{t+1}) = \operatorname*{argmin}_{(b, W) \,:\, \|W\|_\infty \le M|b|}\;\; \frac{1}{2\alpha_t}(b - \tilde\beta_j)^2 + \frac{1}{2\alpha_t}\|W - \tilde\theta_{j\cdot}\|_2^2 + \lambda |b|,$

where $\tilde\beta_j = \beta_j^t - \alpha_t \nabla_{\beta_j}\ell(u^t)$, $\tilde\theta_{j\cdot} = \theta_{j\cdot}^t - \alpha_t \nabla_{\theta_{j\cdot}}\ell(u^t)$, and $\theta_{j\cdot} = (\theta_{j1}, \ldots, \theta_{jK})$ collects the first-layer weights attached to feature $j$.
Now, for each $j$, $(\beta_j^{t+1}, \theta_{j\cdot}^{t+1})$ is the solution to the minimization problem defined by the right-hand side of Eq (6). Despite this problem's nonconvexity, we are able to develop an efficient iterative algorithm for it. Our approach relies crucially on the proposition below, whose proof is deferred to Appendix A.
Fix $\lambda \ge 0$, $M > 0$ and $(v, u) \in \mathbb{R} \times \mathbb{R}^K$. Consider the optimization problem:

(7)  $\displaystyle \operatorname*{minimize}_{(b, W) \in \mathbb{R} \times \mathbb{R}^K}\;\; \frac{1}{2}(v - b)^2 + \frac{1}{2}\|u - W\|_2^2 + \lambda |b| \quad \text{subject to } \|W\|_\infty \le M |b|.$

A sufficient and necessary condition for $(\hat b, \hat W)$ to be the optimum of problem (Eq (7)) is the following: there exist some $s \ge 0$ and $m \in \{0, 1, \ldots, K\}$ such that

(8)  $\displaystyle s = \frac{M}{1 + mM^2}\Big(|v| - \lambda + M\sum_{i=1}^m u_{(i)}\Big)_+, \quad u_{(m+1)} \le s \le u_{(m)}, \quad \hat b = \operatorname{sign}(v)\,\frac{s}{M}, \quad \hat W_k = \operatorname{sign}(u_k)\min(|u_k|, s),$

where $u_{(1)} \ge u_{(2)} \ge \cdots \ge u_{(K)}$ are the order statistics of the absolute values of the coordinates of $u$, with the conventions $u_{(0)} = +\infty$ and $u_{(K+1)} = 0$.

Proposition 6 naturally leads to an efficient method for solving the update in Eq (6), given as Algorithm 1 below.
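A small, self-contained solver for problem (7) follows directly from the case analysis over $m$: try each candidate threshold and keep the one that lands in its own interval. This is our own derivation-based sketch (the function name `hier_prox` is ours), not the paper's Algorithm 1 verbatim.

```python
def hier_prox(v, u, lam, M):
    """Solve problem (7): minimize over (b, W)
        0.5*(v-b)**2 + 0.5*||u-W||**2 + lam*|b|   s.t. ||W||_inf <= M*|b|.

    Implements the characterization in the proposition: for each m, form
    the candidate threshold s and accept it when it lies in the interval
    [u_(m+1), u_(m)] of the sorted absolute values of u.
    """
    K = len(u)
    a = sorted((abs(x) for x in u), reverse=True)  # u_(1) >= ... >= u_(K)
    for m in range(K + 1):
        s = M * max(abs(v) - lam + M * sum(a[:m]), 0.0) / (1.0 + m * M * M)
        hi = a[m - 1] if m >= 1 else float("inf")   # u_(m), with u_(0) = inf
        lo = a[m] if m < K else 0.0                 # u_(m+1), with u_(K+1) = 0
        if lo <= s <= hi:
            b = (1.0 if v >= 0 else -1.0) * s / M
            W = [(1.0 if x >= 0 else -1.0) * min(abs(x), s) for x in u]
            return b, W
    return 0.0, [0.0] * K  # unreachable if the characterization holds
```

The returned pair is feasible by construction, and a brute-force grid check of the one-dimensional profile objective confirms optimality on small examples.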
7 Extensions of the procedure and discussion
There are a number of ways in which the LassoNet procedure can be generalized. These include:

1. Multilayer (deep) networks. One can add additional hidden layers without any change to the constraints. Since the raw features enter only at the first hidden layer, feature sparsity is attained in the same manner.

2. Nonlinear sparse principal components and autoencoders. Here we predict $x$ from itself, and so the final layer has one output connection per feature.

These will be topics of future research.
Acknowledgements: We would like to thank John Duchi and Ryan Tibshirani for helpful comments. Robert Tibshirani was supported by NIH grant 5R01 EB001988-16 and NSF grant DMS-1208164. An R language package for LassoNet will be made freely available.
References
Beck, A. (2017), First-order methods in optimization, Vol. 25, SIAM.

Cock, D. D. (2011), 'Ames, Iowa: Alternative to the Boston housing data as an end of semester regression project', Journal of Statistics Education 19(3). https://doi.org/10.1080/10691898.2011.11889627

Friedman, J. (1991), 'Multivariate adaptive regression splines (with discussion)', Annals of Statistics 19(1), 1-141.

Friedman, J., Hastie, T. & Tibshirani, R. (2010), 'Regularization paths for generalized linear models via coordinate descent', Journal of Statistical Software 33, 1-22.

Guyon, I. & Elisseeff, A. (2003), 'An introduction to variable and feature selection', Journal of Machine Learning Research 3, 1157-1182.

Han, S., Pool, J., Tran, J. & Dally, W. J. (2015), 'Learning both weights and connections for efficient neural networks', Advances in Neural Information Processing Systems 28, 1135-1143.

Parikh, N., Boyd, S. et al. (2014), 'Proximal algorithms', Foundations and Trends in Optimization 1(3), 127-239.

Reed, R. (1993), 'Pruning algorithms: a survey', IEEE Transactions on Neural Networks 4, 740-747.

Rhee, S.-Y., Gonzales, M. J., Kantor, R., Betts, B. J., Ravela, J. & Shafer, R. W. (2003), 'Human immunodeficiency virus reverse transcriptase and protease sequence database', Nucleic Acids Research 31, 298-303.

Scardapane, S., Hussain, A., Uncini, A. & Comminiello, D. (2017), 'Group sparse regularization for deep neural networks', Neurocomputing 241, 81-89.

Tibshirani, R. (1996), 'Regression shrinkage and selection via the lasso', Journal of the Royal Statistical Society, Series B 58, 267-288.

Verikas, A. & Bacauskiene, M. (2002), 'Feature selection with neural networks', Pattern Recognition Letters 23(11), 1323-1335.

Webber, M., Striakus, J., Schumacher, M. & Binder, H. (2018), Network-constrained covariate coefficient and connection sign estimation, Technical report, CORE Discussion Paper 2018/18 or Bank of Lithuania Discussion Paper.

Wen, W., Wu, C., Wang, Y., Chen, Y. & Li, H. (2016), Learning structured sparsity in deep neural networks, in 'Advances in Neural Information Processing Systems', pp. 1-10.

Yoon, J. & Hwang, S. J. (2017), Combined group and exclusive sparsity for deep neural networks, in 'Proceedings of the 34th International Conference on Machine Learning, Volume 70', JMLR.org, pp. 3958-3966.
Appendix A Proof of Proposition 6
We start by proving the claim below: for some $s \ge 0$,

(9)  $\hat W_k = \operatorname{sign}(u_k)\min(|u_k|, s) \quad \text{for each } k = 1, \ldots, K.$
Indeed, note first that $\operatorname{sign}(\hat W_k) = \operatorname{sign}(u_k)$ for each $k$, since otherwise flipping the sign of $\hat W_k$ achieves a strictly smaller objective than $(\hat b, \hat W)$. Now, we denote $s = M|\hat b|$. Certainly, we have

(10)  $|\hat W_k| \le s \quad \text{for each } k = 1, \ldots, K.$
Moreover, by definition, $\hat W$ is the minimum for the optimization problem below:

$\displaystyle \operatorname*{minimize}_{W}\;\; \frac{1}{2}\|u - W\|_2^2 \quad \text{subject to } \|W\|_\infty \le s.$

This optimization problem is convex in $W$. Since Slater's condition holds, we have strong duality for the optimization problem. Hence, there exist some dual variables $\gamma_k \ge 0$ such that $\hat W$ minimizes the Lagrangian function below:

$\displaystyle L(W; \gamma) = \frac{1}{2}\|u - W\|_2^2 + \sum_{k=1}^K \gamma_k \big(|W_k| - s\big).$
Now, we take the subgradient and get the characterization below: for each $k$,

(11)  $\hat W_k - u_k + \gamma_k z_k = 0 \quad \text{for some } z_k \in \partial |\hat W_k|,$

together with the complementary slackness condition $\gamma_k (|\hat W_k| - s) = 0$.
Now we divide our discussion into two cases:

1. $\gamma_k = 0$. By the KKT condition (Eq. (11)), we get that $\hat W_k = u_k$. Together with Eq (10), this requires $|u_k| \le s$.

2. $\gamma_k > 0$. Complementary slackness gives $|\hat W_k| = s$. Note that, if $s > 0$, we must have $z_k = \operatorname{sign}(\hat W_k) = \operatorname{sign}(u_k)$, so Eq (11) gives $\hat W_k = u_k - \gamma_k \operatorname{sign}(u_k)$; hence $\hat W_k = \operatorname{sign}(u_k)\, s$ with $|u_k| = s + \gamma_k \ge s$.

Summarizing the above discussion, we see that for each $k$, $\hat W_k$ must satisfy

(12)  $\hat W_k = \operatorname{sign}(u_k)\min(|u_k|, s).$

Now, Eq (10) and Eq (12) together give the desired claim at Eq (9).
Now, back to the proof of the proposition. Define $g(s)$, the objective of problem (7) after substituting $b = \operatorname{sign}(v)\,s/M$ and $W_k = \operatorname{sign}(u_k)\min(|u_k|, s)$, by

$\displaystyle g(s) = \frac{1}{2}\Big(|v| - \frac{s}{M}\Big)^2 + \lambda\,\frac{s}{M} + \frac{1}{2}\sum_{k=1}^K \big(|u_k| - s\big)_+^2.$

The claim at Eq (9) allows us to reduce the original minimization problem (i.e., Eq (7)) to the problem of finding $s \ge 0$ that minimizes $g$. Note that $g$ is a piecewise smooth (indeed, piecewise quadratic and convex) function. For each $s \in (u_{(m+1)}, u_{(m)})$, we can compute $g(s)$ and get

$\displaystyle g(s) = \frac{1}{2}\Big(|v| - \frac{s}{M}\Big)^2 + \lambda\,\frac{s}{M} + \frac{1}{2}\sum_{i=1}^m \big(u_{(i)} - s\big)^2 + R_m,$

where $R_m$ is some remainder term that is independent of $s$ (but can depend on $m$, $u$ and $v$). Hence, $g$ is smooth for $s \in (u_{(m+1)}, u_{(m)})$. Now, define for each $m \in \{0, 1, \ldots, K\}$,
(13)  $\displaystyle h(m) = \frac{M}{1 + mM^2}\Big(|v| - \lambda + M\sum_{i=1}^m u_{(i)}\Big),$

the stationary point of the smooth quadratic that $g$ coincides with on the $m$th interval. Clearly, if $h(m) \in [u_{(m+1)}, u_{(m)}]$, then $h(m)$ is a local minimum of $g$ over that interval. Now we note the observation below:

(14)  at most one index $\tilde m$ satisfies $h(\tilde m) \in [u_{(\tilde m + 1)}, u_{(\tilde m)}]$.

We defer the proof of this observation. An immediate consequence is that the minimum of $g$ over $s \ge 0$ is attained at the unique $s = h(\tilde m)$ satisfying $h(\tilde m) \in [u_{(\tilde m+1)}, u_{(\tilde m)}]$, or at $s = 0$ when $h(\tilde m) < 0$ (since $g$ is convex); this corresponds to the positive part in Eq (8).
Now, we prove the observation at Eq (14). Suppose that $\tilde m$ satisfies $h(\tilde m) \in [u_{(\tilde m+1)}, u_{(\tilde m)}]$. From the definition of $h$, we have, for each $m$,

(15)  $\displaystyle h(m+1) = \frac{(1 + mM^2)\, h(m) + M^2\, u_{(m+1)}}{1 + (m+1)M^2},$

so that $h(m+1)$ always lies between $h(m)$ and $u_{(m+1)}$; this is the key fact that justifies the uniqueness of $\tilde m$. We first show that $h(m) \ge u_{(m)}$ for all $m > \tilde m$. By the definition of $\tilde m$, $h(\tilde m) \ge u_{(\tilde m + 1)}$, which, by Eq (15), implies that $h(\tilde m + 1)$ lies between $u_{(\tilde m+1)}$ and $h(\tilde m)$; in particular, $h(\tilde m + 1) \ge u_{(\tilde m + 1)}$. Since $u_{(\tilde m+1)} \ge u_{(\tilde m+2)}$, this also implies that $h(\tilde m + 1) \ge u_{(\tilde m + 2)}$, so the same argument can be iterated. Thus, it gives $h(m) \ge u_{(m)}$ for all $m > \tilde m$, so the membership $h(m) \in [u_{(m+1)}, u_{(m)}]$ can hold for $m > \tilde m$ only with equality $h(m) = u_{(m)}$, in which case the resulting value of $s$ is unchanged. Similarly, one can prove $h(m) \le u_{(m+1)}$ for all $m < \tilde m$, and we omit the proof of this part.