Tree ensembles, such as random forests and boosted trees, are renowned for their high prediction performance, but their interpretability is critically limited. In this paper, we propose a post-processing method that improves the interpretability of tree ensembles. After learning a complex tree ensemble in the standard way, we approximate it by a simpler model that is interpretable to humans. To obtain the simpler model, we derive an EM algorithm that minimizes the KL divergence from the complex ensemble. A synthetic experiment showed that a complicated tree ensemble can be reasonably approximated by an interpretable model.
Making Tree Ensembles Interpretable
Satoshi Hara email@example.com
National Institute of Informatics, Japan
JST, ERATO, Kawarabayashi Large Graph Project
Kohei Hayashi firstname.lastname@example.org
National Institute of Advanced Industrial Science and Technology, Japan
Ensemble models of decision trees, such as random forests (Breiman, 2001) and boosted trees (Friedman, 2001), are popular machine learning models, especially for prediction tasks. Because of their high prediction performance, they are among the must-try methods for real-world problems. Indeed, they have attained high scores in many data mining competitions, such as web search ranking (Mohan et al., 2011). These tree ensembles are collectively referred to as Additive Tree Models (ATMs) (Cui et al., 2015).
The main drawback of ATMs is their interpretability. They divide the input space into a large number of small regions and make a prediction depending on the region. Usually, they generate more than a thousand regions, which roughly means that there are thousands of different prediction rules. Non-experts cannot understand such a tremendous number of rules. A decision tree, on the other hand, is well known as one of the most interpretable models. Despite its weak prediction ability, a single tree generates drastically fewer regions, which makes the model transparent and understandable.
There is thus an obvious tradeoff between prediction performance and interpretability. Eto et al. (2014) proposed a simplification method for tree models that prunes redundant branches by approximate Bayesian inference. Wang et al. (2015) studied a similar approach with a richly structured tree model. Although these approaches certainly improve interpretability, prediction performance inevitably degrades, especially when drastic simplification is needed.
Motivated by the above observation, we study how to improve the interpretability of ATMs. We say an ATM is interpretable if its number of regions is sufficiently small (say, fewer than ten). Our goal is then formulated as
1. reducing the number of regions, while
2. minimizing the model error.
To satisfy these conflicting requirements, we propose a post-processing method, which works as follows. First, we learn an ATM in the standard way, which generates a large number of regions (Figure 1(b)). Then, we mimic it with a simple model whose number of regions is fixed to a small value (Figure 1(c)). We refer to the former as the prediction model, or model P, and the latter as the interpretation model, or model I.
Our contributions are summarized as follows.
Separation of prediction and interpretation. We prepare two different models: model P for prediction and model I for interpretation. This separation allows us to balance requirements 1 and 2.
Reformulation of ATMs. We reinterpret an ATM as a probabilistic generative model. With this change of perspective, we can model an ATM as a mixture-of-experts (Jordan & Jacobs, 1994). In addition, this formulation leads to the following optimization algorithm.
Optimization algorithm. To obtain a model I that is close to model P, we consider the KL divergence between model P and model I. To minimize it, we derive an EM algorithm.
For single tree models, many studies on interpretability have been conducted. Probably the most widely used interpretable model is the decision tree, such as CART (Breiman et al., 1984). Oblique decision trees (Murthy et al., 1994) and Bayesian treed linear models (Chipman et al., 2003) extended the decision tree by replacing single-feature partitioning with linear hyperplane partitioning, and the region-wise constant predictor with a linear model, respectively. While hyperplanes and linear models improve the prediction accuracy, they tend to degrade the interpretability of the models owing to their complex structure. Eto et al. (2014) and Wang et al. (2015) proposed tree-structured mixture-of-experts models. These models aim to derive interpretable models while maintaining prediction accuracy as much as possible. We note that all these studies aimed to improve the accuracy–interpretability tradeoff by building a single tree. Thus, although they also deal with tree models, they differ from our study, which aims to interpret ATMs.
The study most relevant to ours is inTrees (interpretable trees), proposed by Deng (2014). It tackles the same problem as ours, namely the post-hoc interpretation of ATMs. The inTrees framework extracts rules from ATMs by balancing the frequency with which rules appear in the trees, the errors of their predictions, and the length of the rules. The fundamental limitation of inTrees is that it targets only classification ATMs. Regression ATMs are first transformed into classification ones by discretizing the output, and inTrees is then applied to extract the rules. The number of discretization levels remains a user-tuned parameter, and it strongly affects the resulting rules. In contrast, our proposed method can handle both classification and regression ATMs.
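For $N \in \mathbb{N}$, $[N]$ denotes the set of integers $\{1, \ldots, N\}$. For a statement $a$, $\mathbb{I}(a)$ denotes the indicator of $a$, i.e., $\mathbb{I}(a) = 1$ if $a$ is true, and $\mathbb{I}(a) = 0$ if $a$ is false.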
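An ATM is an ensemble of decision trees. For simplicity, we consider a simplified version of ATMs. Let $x = (x_1, \ldots, x_D) \in \mathbb{R}^D$ be a $D$-dimensional feature vector. In this paper, we focus on the regression problem, where the target domain is $y \in \mathbb{R}$. Let $h_t(x)$ be the output of the $t$-th decision tree for an input $x$. The output $f(x)$ of the ATM is then defined as a weighted sum of all the tree outputs with weights $w_t$, i.e., $f(x) = \sum_{t=1}^{T} w_t h_t(x)$, where $T$ is the number of trees.
The function $h_t$, the output of tree $t$, is represented as $h_t(x) = \sum_{d=1}^{D_t} \eta_{t,d}\, \mathbb{I}(x \in R_{t,d})$. Here, $D_t$ denotes the number of leaves in tree $t$, $d$ denotes the index of a leaf node in tree $t$, $\eta_{t,d}$ denotes the value predicted at leaf $d$, and $R_{t,d}$ denotes the input region specified by leaf node $d$. Note that the regions are non-overlapping, i.e., $R_{t,d} \cap R_{t,d'} = \emptyset$ for $d \neq d'$. Using this notation, the ATM is written as
$$f(x) = \sum_{g=1}^{G} \eta'_g\, \mathbb{I}(x \in R'_g), \qquad (1)$$
where each index $g$ corresponds to a tuple of leaves $(d_1, \ldots, d_T) \in [D_1] \times \cdots \times [D_T]$, one from each tree, $\eta'_g = \sum_{t=1}^{T} w_t \eta_{t,d_t}$, $R'_g = \bigcap_{t=1}^{T} R_{t,d_t}$, and $G = \prod_{t=1}^{T} D_t$.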
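The region-wise rewriting of the ATM can be checked numerically. The following is a minimal Python sketch under assumptions of our own: a hypothetical two-tree ATM on two features with hand-written axis-aligned leaves (not output from any tree library), where each leaf region is a box of half-open intervals (lo, hi], so that a split sends $x_i \le s$ to the left. The sketch enumerates all leaf tuples $g = (d_1, \ldots, d_T)$, intersects their regions, and confirms that the resulting region-wise constant model reproduces $f(x)$.

```python
from itertools import product

INF = float("inf")

# Hypothetical toy ATM: each tree is a list of (region, leaf value), where a region
# maps a feature index to a half-open interval (lo, hi]. Leaves of a tree do not overlap.
trees = [
    [({0: (-INF, 0.5)}, 1.0), ({0: (0.5, INF)}, 3.0)],           # D_1 = 2 leaves
    [({1: (-INF, 0.2)}, -1.0), ({1: (0.2, 0.7)}, 0.0),
     ({1: (0.7, INF)}, 2.0)],                                     # D_2 = 3 leaves
]
weights = [0.5, 0.5]


def in_region(x, region):
    """True if x lies in the box described by region."""
    return all(lo < x[i] <= hi for i, (lo, hi) in region.items())


def atm_predict(x):
    """f(x) = sum_t w_t h_t(x), with h_t(x) = sum_d eta_{t,d} I(x in R_{t,d})."""
    return sum(w * sum(v for r, v in tree if in_region(x, r))
               for w, tree in zip(weights, trees))


def intersect(regions):
    """Intersection of boxes, or None if it is empty."""
    out = {}
    for region in regions:
        for i, (lo, hi) in region.items():
            a, b = out.get(i, (-INF, INF))
            out[i] = (max(a, lo), min(b, hi))
    if any(lo >= hi for lo, hi in out.values()):
        return None
    return out


def flatten(trees, weights):
    """Rewrite the ATM as a list of (R'_g, eta'_g), one entry per leaf tuple g."""
    flat = []
    for leaves in product(*trees):                 # g = (d_1, ..., d_T)
        region = intersect([r for r, _ in leaves])
        if region is not None:                     # drop empty intersections
            value = sum(w * v for w, (_, v) in zip(weights, leaves))
            flat.append((region, value))
    return flat


def flat_predict(x, flat):
    """sum_g eta'_g I(x in R'_g)."""
    return sum(v for r, v in flat if in_region(x, r))


flat = flatten(trees, weights)
print(len(flat))  # 6 = D_1 * D_2, since every intersection is non-empty here
```

With $T$ trees of $D$ leaves each, the number of leaf tuples grows as $D^T$, which is why $G$ easily exceeds a thousand for realistic ensembles.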
Recall that our goal is to make an interpretation model I by
1. reducing the number of regions, while
2. minimizing the model error.
Based on (1), requirement 1 corresponds to reducing the number of regions $G$. Combining this with requirement 2, the problem we want to solve is defined as follows.
Problem 1 (ATM Interpretation)
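Approximate the ATM (1) (model P) by a simplified model (model I) with only $K$ regions:
$$f(x) \approx \sum_{k=1}^{K} \hat{\eta}_k\, \mathbb{I}(x \in \hat{R}_k), \qquad K \ll G.$$
To solve this approximation problem, we need to optimize both the predictors $\hat{\eta}_k$ and the regions $\hat{R}_k$. The difficulty is that the regions $\hat{R}_k$ are non-numeric parameters and are therefore hard to optimize.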
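To make the target of Problem 1 concrete, the sketch below represents a $K$-region model as a list of rules, each a conjunction of threshold tests with a constant prediction $\hat{\eta}_k$, and measures how far it is from a model P on sample inputs. This is illustrative only: the rule format, `toy_atm`, and the use of mean squared error are our assumptions; the proposed method itself measures proximity with a KL divergence.

```python
# Hypothetical model I with K = 2 rules: (conditions, eta_k), where every condition
# (feature index, threshold, "<=" or ">") must hold for x to fall in region R_k.
rules = [
    ([(0, 0.5, "<=")], 0.5),
    ([(0, 0.5, ">")], 1.5),
]


def rule_predict(x, rules):
    """sum_k eta_k I(x in R_k), assuming the rule regions partition the input space."""
    for conds, eta in rules:
        if all((x[i] <= s) if op == "<=" else (x[i] > s) for i, s, op in conds):
            return eta
    raise ValueError("x is not covered by any rule")


def approximation_error(f, rules, xs):
    """Mean squared gap between model P's output f(x) and model I's output."""
    return sum((f(x) - rule_predict(x, rules)) ** 2 for x in xs) / len(xs)


def toy_atm(x):
    # Hypothetical stand-in for model P: two "trees" whose outputs are summed.
    return (0.5 if x[0] <= 0.5 else 1.5) + (0.1 if x[1] > 0.7 else 0.0)


xs = [(0.2, 0.1), (0.8, 0.9)]
print(approximation_error(toy_atm, rules, xs))
```

Here the two rules ignore the split on `x[1]`, which costs a small error on the second input; this is exactly the accuracy that Problem 1 trades for fewer regions.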
We resolve the difficulty of handling the region parameters by interpreting the ATM as a probabilistic generative model. For simplicity, we consider axis-aligned trees and adopt the following two modifications. These modifications can also be extended to oblique decision trees (Murthy et al., 1994).
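Binary Expression of Features: Let $\{(i_{t,j}, s_{t,j})\}_{j=1}^{J_t}$ be the set of split rules of tree $t$, where $J_t$ is the number of internal nodes. At internal node $j$ of tree $t$, the input $x$ is split by the rule $x_{i_{t,j}} \le s_{t,j}$ or $x_{i_{t,j}} > s_{t,j}$. Let $\{(i_m, s_m)\}_{m=1}^{M}$ be the set of all the split rules, where $M = \sum_{t=1}^{T} J_t$. We then define the binary feature $z \in \{0, 1\}^M$ by $z_m = \mathbb{I}(x_{i_m} \le s_m)$ for $m \in [M]$. Figure 2 shows an example of the binary feature $z$.
Generative Model Expression of Regions: As shown in Figure 2, one region can generate multiple binary features. We can thus interpret a region $R$ as a generative model of the binary feature $z$ using a Bernoulli distribution with parameter $\phi \in [0, 1]^M$: $p(z \mid \phi) = \prod_{m=1}^{M} \phi_m^{z_m} (1 - \phi_m)^{1 - z_m}$. The region $R$ can then be expressed as a pattern of the Bernoulli parameter $\phi$: $\phi_m = 1$ (or $\phi_m = 0$) means that the region satisfies $x_{i_m} \le s_m$ (or $x_{i_m} > s_m$), while $0 < \phi_m < 1$ means that neither $x_{i_m} \le s_m$ nor $x_{i_m} > s_m$ is relevant to the region. In the example of Figure 2, the region generates two different binary features, and its Bernoulli parameter contains an entry that lies strictly between 0 and 1.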
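The two modifications can be sketched in Python as follows. Everything here is illustrative: the split list `splits` and the region parameters `phi_A` and `phi_B` are hypothetical, and we assume the encoding $z_m = \mathbb{I}(x_{i_m} \le s_m)$. The sketch maps inputs to binary features and scores them under a region's Bernoulli parameter, showing that entries equal to 0 or 1 act as hard rules while an entry of 0.5 marks a split as irrelevant.

```python
import math

# Hypothetical split rules (i_m, s_m) pooled over all trees.
splits = [(0, 0.5), (1, 0.2), (1, 0.7)]


def binary_feature(x, splits):
    """Map x to z in {0,1}^M with z_m = 1 iff x[i_m] <= s_m (assumed encoding)."""
    return [1 if x[i] <= s else 0 for i, s in splits]


def bernoulli_log_lik(z, phi):
    """log p(z | phi) = sum_m z_m log phi_m + (1 - z_m) log(1 - phi_m)."""
    total = 0.0
    for zm, pm in zip(z, phi):
        p = pm if zm == 1 else 1.0 - pm
        if p == 0.0:
            return -math.inf   # z violates a rule that phi enforces
        total += math.log(p)
    return total


# Region A = {x0 <= 0.5, x1 <= 0.2}: every bit is fixed (x1 <= 0.2 implies x1 <= 0.7).
phi_A = [1.0, 1.0, 1.0]
# Region B = {x0 > 0.5}: the two splits on x1 are irrelevant, so their entries are 0.5.
phi_B = [0.0, 0.5, 0.5]

z = binary_feature((0.9, 0.1), splits)   # [0, 1, 1]
print(bernoulli_log_lik(z, phi_A), bernoulli_log_lik(z, phi_B))
```

Region B assigns positive probability to several feature vectors, such as `[0, 1, 1]` and `[0, 0, 0]`, which is the sense in which one region generates multiple binary features. Note that the product form treats the bits as independent given the region, even when two splits on the same feature are logically linked.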
Using the Bernoulli parameter expression of regions, we can express the probability that the predictor and the binary feature are generated as
where , , and .
where , and
Here, for each , , , and . We note that the probabilistic model (3) is no longer an ATM but a prediction model with a predictor and a set of rules specified by .
where  is some input generation distribution. The finite sample approximation of (4) is
where , and are determined from the input .
where , , and is an arbitrary distribution on (Beal, 2003). The EM algorithm is then formulated as alternating maximization with respect to (E-step) and the parameters (M-step).
In E-Step, we fix the values of , , , and , and we maximize the lower-bound (6) with respect to the distribution , which yields
In the M-step, we fix the distribution  and maximize the lower bound (6) with respect to , , , and . The maximization over  is a weighted multinomial logistic regression and can be solved efficiently, e.g., by the conjugate gradient method:
The maximization over can be derived analytically as
The maximization over and can also be derived as
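To illustrate the structure of such an EM algorithm, here is a minimal Python sketch under an explicit simplification that we state plainly: it fits a $K$-component mixture in which component $k$ generates the model-P output $y$ from a Gaussian with mean $\mu_k$ and variance $\sigma_k^2$, and the binary feature $z$ from independent Bernoulli parameters $\phi_k$, with constant mixing weights $\alpha_k$. Replacing the gate of model I (whose M-step is a weighted multinomial logistic regression) by constant weights makes every M-step closed-form, so this is not the exact algorithm of this paper. We also assume the training pairs are $(y_n, z_n)$ with $y_n = f(x_n)$ computed by model P. All names are hypothetical.

```python
import math
import random


def em_mixture(ys, zs, K, n_iter=100, seed=0, eps=1e-6):
    """EM for a mixture of N(y | mu_k, var_k) x prod_m Bern(z_m | phi_km) (n_iter >= 1).

    Assumption: constant mixing weights alpha_k replace the logistic-regression gate.
    ys: model-P outputs y_n = f(x_n); zs: binary features z_n as lists of 0/1.
    """
    rng = random.Random(seed)
    N, M = len(ys), len(zs[0])
    # Random soft initialization of the responsibilities q_{nk}.
    q = []
    for _ in range(N):
        r = [rng.random() + 1e-3 for _ in range(K)]
        s = sum(r)
        q.append([v / s for v in r])
    for _ in range(n_iter):
        # M-step: closed-form weighted maximum-likelihood updates.
        Nk = [sum(q[n][k] for n in range(N)) + 1e-12 for k in range(K)]
        alpha = [Nk[k] / N for k in range(K)]
        mu = [sum(q[n][k] * ys[n] for n in range(N)) / Nk[k] for k in range(K)]
        var = [max(sum(q[n][k] * (ys[n] - mu[k]) ** 2 for n in range(N)) / Nk[k], eps)
               for k in range(K)]
        # Clip phi away from 0 and 1 so that log-probabilities stay finite.
        phi = [[min(max(sum(q[n][k] * zs[n][m] for n in range(N)) / Nk[k], eps), 1 - eps)
                for m in range(M)] for k in range(K)]
        # E-step: q_{nk} proportional to alpha_k N(y_n | mu_k, var_k) p(z_n | phi_k).
        for n in range(N):
            logp = []
            for k in range(K):
                lp = math.log(alpha[k])
                lp -= 0.5 * (math.log(2 * math.pi * var[k]) + (ys[n] - mu[k]) ** 2 / var[k])
                lp += sum(math.log(phi[k][m] if zs[n][m] else 1 - phi[k][m])
                          for m in range(M))
                logp.append(lp)
            top = max(logp)                          # log-sum-exp for stability
            w = [math.exp(v - top) for v in logp]
            s = sum(w)
            q[n] = [v / s for v in w]
    return alpha, mu, var, phi, q
```

Under the encoding $z_m = \mathbb{I}(x_{i_m} \le s_m)$, each learned $\phi_k$ can be read as a rule: entries near 1 or 0 give the conditions $x_{i_m} \le s_m$ or $x_{i_m} > s_m$, intermediate entries mark irrelevant splits, and $\mu_k$ is the rule's prediction.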
We evaluated the performance of the proposed method on synthetic data and on the energy efficiency dataset (Tsanas & Xifara, 2012), available at http://archive.ics.uci.edu/ml/datasets/Energy+efficiency. For both datasets, we used the same value of $K$. As model P, we used XGBoost (Chen & Guestrin, 2016).
We compared the proposed method with a decision tree, using DecisionTreeRegressor from scikit-learn in Python. The tree structure was selected from several candidates by 5-fold cross-validation.
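A baseline of this kind can be set up roughly as follows. This is a sketch, not the exact experimental code: the candidate structures are not specified here, so we assume hypothetically that candidates differ in `max_leaf_nodes` and that cross-validation scores mean squared error. The toy data and the helper name `fit_tree_baseline` are also ours. The number of leaves of the selected tree equals the number of rules it produces.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor


def fit_tree_baseline(X, y, candidates=(2, 4, 8, 16, 32)):
    """Select max_leaf_nodes by 5-fold CV (hypothetical candidate grid)."""
    search = GridSearchCV(
        DecisionTreeRegressor(random_state=0),
        param_grid={"max_leaf_nodes": list(candidates)},
        cv=5,
        scoring="neg_mean_squared_error",
    )
    search.fit(X, y)
    return search.best_estimator_   # refit on all of X, y with the best setting


# Toy regression data: a single step in x0 plus Gaussian noise.
rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 2))
y = np.where(X[:, 0] <= 0.5, 1.0, 3.0) + 0.1 * rng.standard_normal(500)

tree = fit_tree_baseline(X, y)
print(tree.get_n_leaves())   # number of rules in the selected tree
```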
We prepared three i.i.d. datasets, which were used for building the ATM, training model I, and evaluating the quality of model I, respectively.
Synthetic Data: We generated regression data as follows:
For each of the three datasets, we generated 1000 samples.
Energy Efficiency Data: The energy efficiency dataset comprises 768 samples and eight numeric features, namely Relative Compactness, Surface Area, Wall Area, Roof Area, Overall Height, Orientation, Glazing Area, and Glazing Area Distribution. The task is regression: predicting the heating load of a building from these eight features. In the experiment, we split the data points into the three datasets for building the ATM, training model I, and evaluation.
The rules found by the proposed method are shown in Figure 1(c) (synthetic data) and in Table 1 (energy efficiency data). On the synthetic data, the proposed method recovered the true data structure shown in Figure 1(a). On the energy efficiency data, Table 1 shows that the found rules are easily interpretable and plausible: the second and third rules indicate that the predicted heating load is small when Relative Compactness is less than 0.75, while the last rule indicates that the predicted heating load is large when Relative Compactness is more than 0.75. These rules are intuitive in that the load is small when the building is small and large when the building is large. Hence, these simplified rules suggest that the ATM has learned in accordance with our intuition about the data.
Table 2 summarizes the evaluation results of the proposed method and a decision tree. On both datasets, the proposed method attained a reasonable prediction error using only four rules. In contrast, the decision tree tended to generate more than 10 rules. On the synthetic data, the decision tree attained a smaller error than the proposed method (0.01 vs. 0.02) but generated 15 rules, nearly four times as many as the proposed method (Figure 3). On the energy efficiency data, the decision tree scored a significantly worse error (168.19 vs. 20.19), indicating that its rules may not be reliable and are thus inappropriate for interpretation. From these results, we conclude that the proposed method is preferable for interpretation because it provides simple rules with reasonable predictive power.
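Table 2: Evaluation results (number of rules / prediction error).
| Dataset | ATM (model P) | Proposed (model I) | Decision tree |
|---|---|---|---|
| Synthetic | 744 / 0.01 | 4 / 0.02 | 15 / 0.01 |
| Energy | 100 / 0.22 | 4 / 20.19 | 37 / 168.19 |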
We proposed a post-processing method that improves the interpretability of ATMs. The difficulty of ATM interpretation comes from the fact that an ATM divides the input space into more than a thousand small regions. We assumed that a model is interpretable if its number of regions is small. Based on this principle, we formulated the ATM interpretation problem as approximating the ATM with a smaller number of regions.
Several open issues remain. For instance, there is freedom in the modeling (3). While we adopted a fairly simple formulation, another model may solve Problem 1 better. Another degree of freedom is the choice of the metric that measures the proximity between model P and model I. We used the KL divergence because it allowed us to derive an EM algorithm for learning model I. The choice of the number of regions $K$ also remains an open issue. Currently, we treat $K$ as a user-tuned parameter, whereas, ideally, its value should be determined automatically.
- Beal (2003) Beal, M. J. Variational Algorithms for Approximate Bayesian Inference. PhD thesis, Gatsby Computational Neuroscience Unit, University College London, 2003.
- Breiman (2001) Breiman, L. Random forests. Machine Learning, 45(1):5–32, 2001.
- Breiman et al. (1984) Breiman, L., Friedman, J., Olshen, R. A., and Stone, C. J. Classification and Regression Trees. CRC Press, 1984.
- Chen & Guestrin (2016) Chen, T. and Guestrin, C. XGBoost: A scalable tree boosting system. arXiv preprint arXiv:1603.02754, 2016.
- Chipman et al. (2003) Chipman, H. A., George, E. I., and McCulloch, R. E. Bayesian treed generalized linear models. Bayesian Statistics, 7, 2003.
- Cui et al. (2015) Cui, Z., Chen, W., He, Y., and Chen, Y. Optimal action extraction for random forests and boosted trees. Proceedings of the 21st ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 179–188, 2015.
- Deng (2014) Deng, H. Interpreting tree ensembles with inTrees. arXiv preprint arXiv:1408.5456, 2014.
- Eto et al. (2014) Eto, R., Fujimaki, R., Morinaga, S., and Tamano, H. Fully-automatic Bayesian piecewise sparse linear models. Proceedings of the 17th International Conference on Artificial Intelligence and Statistics, pp. 238–246, 2014.
- Friedman (2001) Friedman, J. H. Greedy function approximation: A gradient boosting machine. Annals of Statistics, 29(5):1189–1232, 2001.
- Jordan & Jacobs (1994) Jordan, M. I. and Jacobs, R. A. Hierarchical mixtures of experts and the EM algorithm. Neural Computation, 6(2):181–214, 1994.
- Mohan et al. (2011) Mohan, A., Chen, Z., and Weinberger, K. Q. Web-search ranking with initialized gradient boosted regression trees. Journal of Machine Learning Research, Workshop and Conference Proceedings, 14:77–89, 2011.
- Murthy et al. (1994) Murthy, S. K., Kasif, S., and Salzberg, S. A system for induction of oblique decision trees. Journal of Artificial Intelligence Research, 2:1–32, 1994.
- Tsanas & Xifara (2012) Tsanas, A. and Xifara, A. Accurate quantitative estimation of energy performance of residential buildings using statistical machine learning tools. Energy and Buildings, 49:560–567, 2012.
- Wang et al. (2015) Wang, J., Fujimaki, R., and Motohashi, Y. Trading interpretability for accuracy: Oblique treed sparse additive models. Proceedings of the 21st ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1245–1254, 2015.