Explainable Neural Networks based on Additive Index Models
Machine learning algorithms are increasingly being used in recent years due to their flexibility in model fitting and their improved predictive performance. However, the complexity of these models makes it difficult for the data analyst to interpret the results and explain them without additional tools. This has led to much research into developing approaches to understand model behavior. In this paper, we present the Explainable Neural Network (xNN), a structured neural network designed especially to learn interpretable features. Unlike fully connected neural networks, the features engineered by the xNN can be extracted from the network in a relatively straightforward manner and the results displayed. With appropriate regularization, the xNN provides a parsimonious explanation of the relationship between the features and the output. We illustrate this interpretable feature-engineering property on simulated examples.
Neural networks (NNs) and ensemble algorithms such as Gradient Boosting Machines (GBMs) and Random Forests (RFs) have become popular in recent years due to their predictive power and flexibility in model fitting. They are especially useful with large data sets, where handcrafted variable selection and feature engineering are difficult, and in these situations they often have substantially better predictive performance than traditional statistical methods. Despite these advantages, there has been reluctance to fully adopt them. One of the primary barriers to widespread adoption is the “black box” nature of such models: the models are very complex and cannot be written down explicitly, so it is difficult for a modeler to explain the relationships between the input features and the response or, more generally, to understand the model’s behavior. However, the ability to interpret a model and explain its behavior is critical in certain industries, such as medicine and health care, which deal with high risk, or banking and finance, which are strongly regulated. For instance, in banking, regulators require that input-output relationships be consistent with business knowledge and that the model include key economic variables that must be used for stress testing.
These challenges have led to much recent research into developing tools to “open up the black box”. There are, broadly speaking, three inter-related model-based areas of research: a) global diagnostics (Sobol & Kucherenko (2009), Kucherenko (2010)); b) local diagnostics (Sundararajan et al. (2017), Ancona et al. (2018)); and c) development of approximate or surrogate models that may be easier to understand and explain, which may be either global (Hinton et al. (2015), Bucilua et al. (2006), Tan et al. (2018)) or local (Hu et al. (2018)) in nature. There are also efforts to understand neural networks using visualization-based techniques, such as those described in Kahng et al. (2017) and Olah et al. (2017).
In this paper, we propose a flexible, yet inherently explainable, model. More specifically, we describe a structured network that imposes some constraints on the network architecture and thereby provides better insights into the underlying model. We refer to it as explainable neural network, or xNN. The structure provides a means to understand and describe the features engineered by the network in terms of linear combinations of the input features and univariate non-linear transformations.
We use the terms “interpretable” and “explainable” interchangeably in this paper although, strictly speaking, they have different meanings: the former refers to the ability to understand and interpret results for oneself, while the latter is the ability to explain the results to someone else. Interpretability can thus be viewed as a precursor to explainability, but we do not make that distinction in this paper.
Explainability by itself is not enough without also considering predictive performance. For instance, a linear model is very explainable, but it is likely to perform poorly when approximating a complex surface. In the simple examples considered in this paper, the xNNs have excellent predictive performance, but additional research on more complex examples is needed and is currently being pursued.
Feedforward neural networks typically consist of fully connected layers, i.e., the output of each node on layer l is used as input to each node on layer l + 1. By limiting the connections between nodes, we can give a feedforward neural network structure that can be exploited for different purposes. For example, Tsang et al. (2018) considered a structure to detect interactions among input features in the presence of the features’ main effects. In this paper, we propose a structured neural network designed to be explainable, meaning that it is relatively easy to describe the features and nonlinear transformations learned by the network via the network structure. It is based on the concept of additive index models (Ruan & Yuan (2010), Yuan (2011)) and is related to projection pursuit and generalized additive models (Hastie & Tibshirani (1986)).
The remainder of the paper is organized as follows. In Section 2, we review the additive index model, and in Section 3 we introduce the explainable neural network architecture. In Section 4, we illustrate how the components of the xNN may be used to describe the engineered features the network learns from the input variables. Section 5 discusses several considerations that arise in using such networks in practice. Finally, we provide additional examples of trained xNN models in Section 6.
2 Additive Index Models
The formal definition of an additive index model is given in (1):

f(x) = g_1(β_1ᵀx) + g_2(β_2ᵀx) + … + g_K(β_Kᵀx),   (1)

where the function f(x) on the LHS is expressed as a sum of K smooth univariate functions g_i(·) (Ruan & Yuan, 2010), each applied to a linear combination of the input features x. The coefficient vectors β_i are often referred to as projection indices, and the g_i’s are referred to as ridge functions, following Friedman & Stuetzle (1981). See also Hastie & Tibshirani (1986) for the related notion of generalized additive models. The additive index model in (1) provides a flexible framework for approximating complex functions. In fact, as shown in Diaconis & Shahshahani (1984), additive index models can approximate any multivariate function with arbitrary accuracy provided K, the number of ridge functions, is sufficiently large. In practice, additive index models can be fit using penalized least squares methods to simultaneously fit the model and select an appropriate number of ridge functions (Ruan & Yuan (2010)). See also Yuan (2011) for a discussion of identifiability issues surrounding such models.
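To make the structure of (1) concrete, the following NumPy sketch evaluates a small additive index model. The projection indices and ridge functions here are made-up illustrations, not taken from the paper:

```python
import numpy as np

# Hypothetical example of f(x) = g_1(beta_1' x) + g_2(beta_2' x):
# beta_1 picks out x1 alone, beta_2 averages x2 and x3.
betas = np.array([[1.0, 0.0, 0.0],
                  [0.0, 0.5, 0.5]])
ridge_fns = [np.tanh, np.square]   # example ridge functions g_1, g_2

def additive_index_model(x):
    """Evaluate f(x) = sum_i g_i(beta_i . x) for a single input vector x."""
    projections = betas @ x                     # one linear combination per g_i
    return sum(g(z) for g, z in zip(ridge_fns, projections))

x = np.array([0.2, -0.4, 0.8])
print(additive_index_model(x))   # tanh(0.2) + (0.5*(-0.4) + 0.5*0.8)**2
```

Any choice of smooth univariate g_i's fits the framework; the xNN described below learns both the β_i's and the g_i's from data.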
3 Explainable Neural Network Architecture (xNN)
The Explainable Neural Network provides an alternative formulation of the additive index model as a structured neural network. It also provides a direct approach for fitting the model via gradient-based training methods for neural networks. The resulting model has built-in interpretation mechanisms as well as automated feature engineering. We discuss these mechanisms in more detail in Section 4. Here, we describe the architecture of the xNN.
We define a modified version of the additive index model in (1) as follows:

f(x) = μ + γ_1 h_1(β_1ᵀx) + γ_2 h_2(β_2ᵀx) + … + γ_K h_K(β_Kᵀx).   (2)

Although the shift parameter μ and the scale parameters γ_i are not identifiable (they could be absorbed into the ridge functions h_i), they are useful for the purposes of model fitting: penalizing the γ_i’s provides a way to select an appropriate number of ridge functions through regularization.
The structure of an xNN is designed to explicitly learn the model given in equation (2). Figure 1 illustrates the architecture of an xNN. The input layer is fully connected to the first hidden layer (called the projection layer), which consists of K nodes, one for each ridge function. The weights of the i-th node in the projection layer correspond to the coefficients β_i of the input to the corresponding ridge function. The projection layer uses a linear activation function to ensure that each node in this layer learns a linear combination of the input features. The output of each node in the projection layer is used as the input to exactly one subnetwork.
Subnetworks are used to learn the ridge functions h_i. The external structure of the subnetworks is essential to the xNN: each subnetwork must have univariate input and output, and there must be no connections between subnetworks. The internal structure of the subnetworks is less critical, provided they have sufficient capacity to learn a broad class of univariate functions. Subnetworks typically consist of multiple fully connected layers and use nonlinear activation functions. More details are discussed in Section 5.3.
The combination layer is the final hidden layer of the xNN and consists of a single node. Its inputs are the univariate activations of all of the subnetworks. The weights learned correspond to the γ_i’s in equation (2) and provide a final weighting of the ridge functions. A linear activation function is used on this layer, so the output of the network as a whole is a linear combination of the ridge functions. (Note: a non-linear activation function may easily be used on the combination layer instead of a linear activation. This changes the formulation given in (2) by wrapping the LHS in a further link function, as with generalized linear models. We do not explore this generalization in detail here.)
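Putting the three pieces together, a minimal NumPy sketch of the xNN forward pass looks as follows. The weights are randomly initialized (untrained), and the layer sizes are arbitrary choices for illustration; a real implementation would use a deep learning framework and fit these weights by gradient descent:

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_subnets, hidden = 5, 3, 8

# Projection layer: linear activation, one node per subnetwork (rows are beta_i).
proj_W = rng.normal(size=(n_subnets, n_features))

# Each subnetwork: univariate input -> one tanh hidden layer -> univariate output.
sub_W1 = rng.normal(size=(n_subnets, hidden))   # 1 -> hidden weights
sub_b1 = np.zeros((n_subnets, hidden))
sub_W2 = rng.normal(size=(n_subnets, hidden))   # hidden -> 1 weights

# Combination layer: linear, with weights gamma_i and shift mu.
gamma = rng.normal(size=n_subnets)
mu = 0.0

def xnn_forward(X):
    """Forward pass for a batch X of shape (n, n_features)."""
    Z = X @ proj_W.T                              # (n, n_subnets) projections
    ridge_out = np.empty_like(Z)
    for i in range(n_subnets):                    # no links between subnetworks
        H = np.tanh(np.outer(Z[:, i], sub_W1[i]) + sub_b1[i])
        ridge_out[:, i] = H @ sub_W2[i]           # univariate output h_i
    return mu + ridge_out @ gamma                 # mu + sum_i gamma_i h_i(beta_i . x)

X = rng.normal(size=(4, n_features))
print(xnn_forward(X).shape)   # (4,)
```

Because the subnetworks share no connections, each learned ridge function can be read off the network after training simply by passing a grid of univariate inputs through the corresponding subnetwork.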
The neural network based formulation of the additive index model provides some advantages over the traditional fitting approaches in the statistics literature. First, it may be trained using standard mini-batch gradient-based methods, allowing the xNN formulation to be trained easily on datasets that are too large to fit in memory all at once. Further, the neural network formulation allows the xNN to take advantage of the advancements in GPU computing used to train neural networks in general. Finally, the neural network formulation allows for straightforward computation of partial derivatives of the function learned by the xNN. This supports derivative-based analysis techniques, such as those presented in Sobol & Kucherenko (2009) and Kucherenko (2010), without needing to rely on finite difference approximations and the difficulties these may cause.
In the next section, we illustrate how the structures built into the xNN, namely the projection layer and subnetworks, provide a mechanism to explain the function learned by such a network.
4 Visualization and Explainability of the xNN
We now illustrate how visualization of xNN components can be used to aid in explainability. We consider a simple toy example based on the first three Legendre polynomials, shown in Figure 2 and defined in (3):

p_1(x) = x,   p_2(x) = (3x² − 1)/2,   p_3(x) = (5x³ − 3x)/2.   (3)

These polynomials are orthogonal on the interval [−1, 1] and have a range of [−1, 1] over the same interval. The exact form of these functions is not of particular interest except for the fact that they provide distinct linear, quadratic, and cubic functions on a similar scale and are orthogonal.
We simulated five independent variables, x_1, …, x_5, from a Uniform distribution on [−1, 1]. We then generated y via

y = p_1(x_1) + p_2(x_2) + p_3(x_3) + ε,   (4)

where p_1, p_2, and p_3 are the Legendre polynomials described in (3) and ε is a small Gaussian noise term. This leaves x_4 and x_5 as noise variables.
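The simulation is straightforward to reproduce in NumPy. The sample size and noise scale below are our assumptions, not values stated in the paper:

```python
import numpy as np

# First three Legendre polynomials, orthogonal on [-1, 1].
def p1(x): return x
def p2(x): return 0.5 * (3 * x**2 - 1)
def p3(x): return 0.5 * (5 * x**3 - 3 * x)

rng = np.random.default_rng(42)
n = 10_000                                      # assumed sample size
X = rng.uniform(-1, 1, size=(n, 5))             # x1..x5 iid Uniform(-1, 1)
y = p1(X[:, 0]) + p2(X[:, 1]) + p3(X[:, 2])     # x4, x5 are pure noise features
y += rng.normal(scale=0.05, size=n)             # assumed small Gaussian noise
```

A well-regularized xNN fit to (X, y) should recover one linear, one quadratic, and one cubic ridge function, each projecting onto a single feature.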
We then built an xNN model with 5 subnetworks and trained it on all five features (x_1, …, x_5). Only the strength of the ℓ1 penalty on the projection and output layers was tuned. The resulting xNN was used to generate the summaries that follow.
4.1 Visualizing Ridge Functions
Figure 3 shows the ridge functions. Row i represents subnetwork i, for i = 1, …, 5. The first column illustrates the univariate function learned by subnetwork i, scaled by γ_i. These plots illustrate the univariate, non-linear transformations learned by the xNN in training. The second column displays the values of β_i, the projection coefficients. The projection coefficients show which combination of input features is used as input to each ridge function. In this way, the plot displays the most relevant features of the network: the scaled ridge functions and the projection coefficients.
In this example, we see from Figures 3 and 4 that Subnetwork 1 has learned the cubic Legendre function (p_3) and, from the second column of Figure 3, that only x_3 has a non-zero coefficient in the input to this subnetwork. Subnetwork 2 has learned the quadratic function (p_2), and only the coefficient of x_2 is nonzero. Subnetwork 5 has learned the linear function (p_1), and only the coefficient of x_1 is non-zero. The other subnetworks (3 and 4) are not needed, and are set to zero by the ℓ1 penalty on the ridge function weights (the γ_i’s in (2)).
4.2 Visualizing Univariate Effects
The plot shown in Figure 4 illustrates the feature-centric view of the xNN, which we refer to as conditional effects. In this view, the j-th row summarizes the xNN’s treatment of the j-th feature. In the first column, each subnetwork’s treatment of feature x_j is plotted in row j, calculated via γ_i h_i(β_ij x_j), with all other features held fixed at zero. Each dotted line represents one such subnetwork, while the bold, solid line represents the effect of the network as a whole on feature x_j. This is calculated via μ + Σ_i γ_i h_i(β_ij x_j), and is the sum of the conditional effects of the individual subnetworks. If the data have been standardized (as is typical in this case), this is equivalent to plotting f(x) as x_j varies with all other features held at their mean of zero.
The second column of Figure 4 shows the projection coefficient of feature x_j for each of the subnetworks, i.e., β_ij for i = 1, …, 5. This shows which ridge functions are used to describe the effects of x_j.
In this particular example, we see that the only nonzero coefficient of x_1 is in the projection for subnetwork 5, the linear function, and that the conditional effect on x_1 is linear. Similarly, the only nonzero coefficient of x_2 appears in subnetwork 2, which learned a quadratic function. The only nonzero coefficient of x_3 is in subnetwork 1, which has learned the cubic function (p_3). The two extraneous variables, x_4 and x_5, have no non-zero coefficients, so the overall conditional effect of these variables is constant.
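The conditional-effect curves can be computed directly from the fitted quantities. In this sketch, the projection coefficients, scale weights, and ridge functions are made-up stand-ins for quantities read off a trained xNN:

```python
import numpy as np

# Hypothetical fitted quantities: two subnetworks over two features.
betas = np.array([[1.0, 0.2],    # beta_1
                  [0.0, 1.0]])   # beta_2
gammas = np.array([0.8, -0.5])   # gamma_1, gamma_2
ridge = [np.tanh, np.square]     # learned ridge functions h_1, h_2

def conditional_effect(j, grid):
    """Network-level conditional effect on feature j over `grid`,
    all other features held at zero: sum_i gamma_i h_i(beta_ij * x_j)."""
    total = np.zeros_like(grid)
    for i, h in enumerate(ridge):
        total += gammas[i] * h(betas[i, j] * grid)   # one dotted line per i
    return total                                     # the bold, solid line

grid = np.linspace(-1, 1, 5)
print(conditional_effect(0, grid))
```

Plotting each summand separately gives the dotted per-subnetwork curves in Figure 4, and `total` gives the bold network-level curve.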
It should be mentioned that the conditional effects plot shows some information that is redundant with the subnetwork–centric view. Nonetheless, the alternate view can be useful in understanding the role each feature plays in the predictions of the xNN model.
In this toy example, the effect of each feature is represented by exactly one ridge function. In more complex situations, multiple ridge functions may be, and often are, involved in representing the effect of a particular variable. Furthermore, in under-regularized networks, the effect of each variable may be modeled by the contributions of several subnetworks. This behavior is displayed in the examples in Section 6.
5 Practical Considerations
In this section, we consider some of the practical considerations that arise when using such models. These include a brief discussion on the difference between model recoverability and explainability, regularization of the xNN needed to learn a parsimonious model, and the structure of the subnetworks.
5.1 Model Recoverability and Explainability
In practice, fitted xNN models exist on a spectrum of model recoverability while retaining a high degree of explainability. By model recoverability, we refer to the ability to recover the underlying generative mechanisms of the data; explainability refers to the xNN’s ability to describe the mechanisms the network uses to approximate a complex multivariate function, even if these mechanisms do not faithfully recover the underlying data generating process. With proper regularization, as discussed in Section 5.2, the representation is parsimonious and straightforward to interpret. The example discussed previously in Section 4 illustrates a situation where the xNN has high model recoverability, meaning that it has clearly learned the underlying generating process. In practice, this is not always the case, as the data generating process may not be fully described by the additive index model. In Section 6.2, we see such an example, where the model is explainable even though it does not have high model recoverability.
In practice, the user typically will not know where on this spectrum a given xNN sits. However, unlike other popular network structures (such as fully connected feedforward networks) or tree-based methods, the xNN has a built-in mechanism to describe the complex function learned by the network in the relatively simple terms of projections and univariate ridge functions, ensuring the model is explainable regardless of where it falls on the model recoverability spectrum.
Finally, note that in certain circumstances, model recoverability may not be desirable. If the data generating process is highly complex, the explainable xNN is likely to be more easily understood given its additive nature. The xNN is especially easy to understand if it has been properly regularized.
5.2 Regularization and Parsimony
The overall explainability of the network can be enhanced by using an ℓ1 penalty on both the first and last hidden layers during training. That is, both the projection coefficients (the β_i’s) and the ridge function weights (the γ_i’s) are penalized. When the strength of the penalty is properly tuned, this produces a parsimonious model that is relatively easily explained.

An ℓ1 penalty on the first hidden layer forces the projection vectors to have few non-zero entries, meaning that each subnetwork (and corresponding ridge function) is applied to only a small set of the variables. Similarly, an ℓ1 penalty on the final layer forces the γ_i’s to zero in situations where fewer subnetworks are needed than are specified in training.
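The penalized objective can be written down as MSE plus two ℓ1 terms. The function and parameter names in this sketch are ours; in a deep learning framework the same effect is obtained by attaching ℓ1 regularizers to the projection and combination layers:

```python
import numpy as np

def l1_penalized_loss(y, y_hat, proj_W, gamma, lam_beta, lam_gamma):
    """Mean squared error plus l1 penalties on the projection coefficients
    (the beta's, first hidden layer) and the ridge-function weights
    (the gamma's, final layer). lam_beta and lam_gamma are tuning
    constants chosen by the modeler (e.g., by cross-validation)."""
    mse = np.mean((y - y_hat) ** 2)
    return (mse
            + lam_beta * np.abs(proj_W).sum()     # sparsifies each projection
            + lam_gamma * np.abs(gamma).sum())    # prunes unneeded subnetworks
```

Larger `lam_gamma` drives whole ridge functions to zero (fewer subnetworks used), while larger `lam_beta` makes each remaining projection depend on fewer input features.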
5.3 Subnetwork Structure
In principle, the subnetwork structure must be chosen so that each subnetwork is capable of learning a large class of univariate functions. In our experience, however, neither the explainability nor the predictive performance of the network is highly sensitive to the subnetwork structure. In our simulations, we have found that subnetworks consisting of two hidden layers, with structures such as [25, 10] or even [12, 6] and nonlinear activation functions (e.g., tanh), are flexible enough to learn the ridge functions needed to fit the models.
5.4 xNN as a Surrogate Model
While the xNN architecture may be used as an explainable predictive model built directly from data, it may also be used as a surrogate model to explain other nonparametric models, such as tree-based methods and feedforward neural networks, which we call the base model. Because the xNN is an explainable model, we may train an xNN using the input features and the corresponding response values predicted by the base model, and then use the xNN to explain the relationships learned by the base model. For further discussion of surrogate models, see Hinton et al. (2015), Bucilua et al. (2006), or Tan et al. (2018). The use of more easily interpretable surrogate models to help interpret a complex machine learning model is similar to the field of computer experiments, where complicated computer simulations of physical systems are studied using well-understood statistical models, as described in Fang, Li & Sudjianto (2005) and Bastos & O’Hagan (2009). In computer experiments, the input to the computer simulation may be carefully designed to answer questions of interest using these statistical models, whereas complex ML models are often restricted to observational data.
6 Simulation Examples
In this section, we illustrate the behavior of xNN networks with two simulations. In the first, data are generated from a model that follows the additive index model framework. This is an example where the trained xNN has high model recoverability, meaning that it correctly recovers the data generating mechanism. The second simulation does not follow the additive index model framework, yet the trained xNN is still explainable, in the sense that it provides a clear description of the mechanisms it learns to approximate the underlying response surface.
6.1 Example 1: Linear Model with Multiplicative Interaction
We simulate six independent variables, x_1, …, x_6, from independent Uniform distributions on [−1, 1]. We then generate y via

y = c_1 x_1 + c_2 x_2² + c_3 x_3² + c_4 x_4x_5 + ε,   (5)

where the c_i are fixed coefficients and ε is a small Gaussian noise term.
This is a linear model with a multiplicative interaction. The variable x_6 is left as a noise feature. While this model does not, at first glance, fit the additive index model framework, we note that a multiplicative interaction may be represented as the sum of two quadratic ridge functions, as shown in (6):

x_4x_5 = ¼[(x_4 + x_5)² − (x_4 − x_5)²].   (6)
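The quadratic-difference identity for a multiplicative interaction is easy to verify numerically; a minimal NumPy check:

```python
import numpy as np

# Check x*y = ((x + y)**2 - (x - y)**2) / 4 on random inputs: a product term
# can be written as a sum of two quadratic ridge functions applied to the
# projections x + y and x - y.
rng = np.random.default_rng(0)
x, y = rng.uniform(-1, 1, size=(2, 1000))
assert np.allclose(x * y, ((x + y) ** 2 - (x - y) ** 2) / 4)
```

This is exactly why an xNN with quadratic ridge functions on the projections x_4 + x_5 and x_4 − x_5 can represent the interaction term exactly.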
Therefore, this model may be exactly represented by an xNN. We trained the xNN using 20 subnetworks. The network achieved a mean squared error of 0.0028 on the holdout set, close to the simulation lower bound of 0.0025. The resulting active ridge functions are illustrated in Figure 5. (By active ridge functions, we mean those that are not constant.) Subnetwork 9 learned a linear ridge function and has a relatively large projection coefficient for x_1. Subnetworks 2, 4, 5, and 16 learned quadratic ridge functions. Based on the projection coefficients, we see that subnetworks 2 and 5 are used to represent the contributions of x_2 and x_3, respectively. Subnetworks 4 and 16 combine to represent the interaction x_4x_5: the two features have the same projection coefficients in subnetwork 16, while they have projection coefficients of opposite signs in subnetwork 4. This is exactly the representation of an interaction term described in equation (6). Thus, this xNN has both high model recoverability and a high degree of explainability.
Figure 6 illustrates the conditional effect of the network on each of the predictors. We see, as expected, a linear conditional effect on x_1 and quadratic effects on x_2 and x_3. It is notable that the conditional effects plots for both x_4 and x_5 show no conditional effect. In the case of such interactions, this is expected: in this model, if we condition on x_5 = 0, then x_4 will show no effect on the response. Similarly, we see no effect of x_5 when conditioning on x_4 = 0.
6.2 Example 2: Non-Linear Model
We simulate four independent variables, x_1, …, x_4, from independent Uniform distributions on [−1, 1]. We then generate y as a nonlinear, non-additive function of x_1 and x_2, plus a small Gaussian noise term.
Both x_3 and x_4 are left as noise variables. We then fit an xNN with 10 subnetworks and a subnetwork structure of [12, 6] with tanh activations. The network achieved a mean squared error of 0.0122 on a holdout test set, close to the simulation lower bound of 0.01.
Note that this generating model does not fit the additive index model framework. In this example, the trained xNN is explainable despite having low model recoverability. Although the xNN cannot recover the data generating process, it still fits the data well, and clearly explains the mechanisms it uses to do so, by displaying the projection coefficients and learned ridge functions.
Figure 7 shows the two active ridge functions, represented by subnetworks 2 and 5. Both subnetworks have non-zero coefficients for x_1 and x_2, although they have the same sign in Subnetwork 5 and opposite signs in Subnetwork 2. We see that the xNN approximates the simulated function with a function of the form γ_2 h_2(β_21 x_1 + β_22 x_2) + γ_5 h_5(β_51 x_1 + β_52 x_2), where h_2 and h_5 are the two ridge functions learned by subnetworks 2 and 5, respectively.
Figure 8 shows the corresponding conditional effects. Note that subnetwork 3 has learned a small non-zero projection coefficient; however, the corresponding ridge function is constant at zero, so it does not contribute to the output. This type of behavior may occur when the xNN is slightly under-regularized.
7 Conclusion

We have proposed an explainable neural network architecture, the xNN, based on the additive index model. Unlike commonly used neural network structures, the structure of the xNN describes the features it learns, via linear projections and univariate ridge functions. These explanations have the attractive property of being additive in nature and straightforward to interpret. Whether the network is used as a primary model or as a surrogate for a more complex model, the xNN provides straightforward explanations of how the model uses the input features to make predictions.
Future work on the xNN will study the overall predictive performance of the xNN compared to other ML models, such as GBM and unconstrained FFNNs. We will also study the predictive performance lost when using the xNN as a surrogate model for more complex models.
References

- Ancona et al. (2018) Ancona, M., Ceolini, E., Oztireli, C. & Gross, M. (2018), Towards better understanding of gradient-based attribution methods for deep neural networks, in ‘6th International Conference on Learning Representations’.
- Bastos & O’Hagan (2009) Bastos, L. S. & O’Hagan, A. (2009), ‘Diagnostics for Gaussian process emulators’, Technometrics pp. 425–438.
- Bucilua et al. (2006) Bucilua, C., Caruana, R. & Niculescu-Mizil, A. (2006), Model compression, in ‘ICDM’.
- Diaconis & Shahshahani (1984) Diaconis, P. & Shahshahani, M. (1984), ‘On nonlinear functions of linear combinations’, SIAM J. Sci. and Stat. Comput 5(1), 175–191. https://doi.org/10.1137/0905013.
- Fang et al. (2005) Fang, K.-T., Li, R. & Sudjianto, A. (2005), Design and Modeling for Computer Experiments, Chapman and Hall/CRC.
- Friedman & Stuetzle (1981) Friedman, J. H. & Stuetzle, W. (1981), ‘Projection pursuit regression’, Journal of the American Statistical Association 76(376), 817–823.
- Hastie & Tibshirani (1986) Hastie, T. & Tibshirani, R. (1986), ‘Generalized additive models’, Statist. Sci. 1(3), 297–310.
- Hinton et al. (2015) Hinton, G., Vinyals, O. & Dean, J. (2015), Distilling the knowledge in a neural network, in ‘NIPS Deep Learning Workshop’.
- Hu et al. (2018) Hu, L., Chen, J., Nair, V. N. & Sudjianto, A. (2018), ‘Locally interpretable models and effects based on supervised partitioning (lime-sup)’, arXiv preprint arXiv:1806.00663 .
- Kahng et al. (2017) Kahng, M., Andrews, P. Y., Kalro, A. & Chau, D. H. (2017), ‘Activis: Visual exploration of industry-scale deep neural network models’, CoRR abs/1704.01942.
- Kucherenko (2010) Kucherenko, S. (2010), ‘A new derivative based importance criterion for groups of variables and its link with the global sensitivity indices’, Computer Physics Communications 181(7), 1212–1217.
- Olah et al. (2017) Olah, C., Mordvintsev, A. & Schubert, L. (2017), ‘Feature visualization’, Distill . https://distill.pub/2017/feature-visualization.
- Ruan & Yuan (2010) Ruan, L. & Yuan, M. (2010), Dimension reduction and parameter estimation for additive index models.
- Sobol & Kucherenko (2009) Sobol, I. & Kucherenko, S. (2009), ‘Derivative based global sensitivity measures and their link with global sensitivity indices’, Mathematics and Computers in Simulation (MATCOM) 79(10), 3009–3017.
- Sundararajan et al. (2017) Sundararajan, M., Taly, A. & Yan, Q. (2017), ‘Axiomatic attribution for deep networks’, arXiv preprint arXiv:1703.01365 .
- Tan et al. (2018) Tan, S., Caruana, R., Hooker, G. & Gordo, A. (2018), ‘Transparent model distillation.’, arXiv preprint arXiv:1801.08640 .
- Tsang et al. (2018) Tsang, M., Cheng, D. & Liu, Y. (2018), ‘Detecting statistical interactions from neural network weights’, International Conference on Learning Representations . accepted as poster.
- Yuan (2011) Yuan, M. (2011), ‘On the identifiability of additive index models’, Statistica Sinica 21(4), 1901–1911.