Weighted Risk Minimization & Deep Learning
Importance weighting is a key ingredient in many algorithms for causal inference and related problems, such as off-policy evaluation for reinforcement learning. Recently, theorists proved that on separable data, unregularized linear networks trained with cross-entropy loss and optimized by stochastic gradient descent converge in direction to the max-margin solution. This solution depends on the location of data points but not their weights, nullifying the effect of importance weighting. This paper asks: for realistic deep networks, which can separate essentially any dataset of practical interest, what is the effect of importance weighting? Lacking theoretical tools for analyzing modern deep (nonlinear, unregularized) networks, we investigate the question empirically on both realistic and synthetic data. Our results demonstrate that while importance weighting alters the learned model early in training, its effect diminishes to negligible levels as training continues indefinitely. However, this diminishing effect does not occur in the presence of L2-regularization. These results (i) support the broader applicability of the theoretical findings of Soudry et al. (2017), who analyze linear networks; (ii) call into question the practice of importance weighting; and (iii) suggest that its usefulness interacts strongly with early stopping criteria and with regularization methods that act on the loss function.
Importance sampling is a fundamental tool in statistics and machine learning that is used when we want to estimate some quantity on a particular target distribution, but can only sample from a different source distribution (Horvitz and Thompson, 1952; Kahn and Marshall, 1953; Rubinstein and Kroese, 2016; Koller et al., 2009). This situation arises in a wide variety of tasks including propensity score matching in causal inference (Rosenbaum and Rubin, 1983), active learning (Beygelzimer et al., 2009), domain adaptation (Lipton et al., 2018), and off-policy reinforcement learning (Precup, 2000). One commonly used method to overcome this problem is to weight examples according to the likelihood ratio of the two distributions (Rubinstein and Kroese, 2016; Shimodaira, 2000; Koller et al., 2009), corresponding to the common technique of importance-weighted empirical risk minimization (IW-ERM).
In the typical classification setup, when the data is separable and we optimize cross-entropy loss, there is no finite optimal solution. Addressing the simple case of linear networks, and noting that only the direction of the weights (but not their magnitude) determines the separating hyperplane, Soudry et al. (2017) ask what, if anything, the direction converges to. Surprisingly, they conclude that it converges to the solution of the hard-margin support vector machine.
An interesting ramification follows from this theoretical result in conjunction with importance weighting. The hard-margin solution depends only on the location of data points, and thus is unaffected by oversampling or re-weighting. While the analyses of Soudry et al. (2017) and Gunasekar et al. (2018) only address linear networks and separable data, we consider the ramifications if similar results were to hold for typical modern deep (nonlinear) neural networks, for which all datasets of practical interest are separable (Zhang et al., 2017).
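The invariance can be made concrete in one dimension. In this minimal sketch (pure Python; the data values and function names are hypothetical), every positive lies to the right of every negative, so the hard-margin threshold is the midpoint of the gap between the closest opposite-class points. Duplicating a point, which is equivalent to giving it an integer importance weight, leaves that threshold unchanged, whereas a density-dependent rule such as the midpoint of the class means moves.

```python
def hard_margin_threshold_1d(pos, neg):
    """Max-margin threshold for 1-D data with every positive to the right of every negative.

    Only the closest point of each class (the support vectors) matters.
    """
    gap_lo, gap_hi = max(neg), min(pos)
    if gap_lo >= gap_hi:
        raise ValueError("classes are not separable by a threshold")
    return (gap_lo + gap_hi) / 2.0


def centroid_threshold_1d(pos, neg):
    """Midpoint of the two class means: depends on how often each point appears."""
    return (sum(pos) / len(pos) + sum(neg) / len(neg)) / 2.0


pos = [1.0, 2.0, 5.0]
neg = [-1.0, -3.0]
# Oversample the far-away positive point ten extra times (an integer weight of 11).
pos_oversampled = pos + [5.0] * 10

print(hard_margin_threshold_1d(pos, neg), hard_margin_threshold_1d(pos_oversampled, neg))
print(centroid_threshold_1d(pos, neg), centroid_threshold_1d(pos_oversampled, neg))
```

The first line prints the same hard-margin threshold twice; the second shows the centroid rule shifting toward the oversampled point.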
An interesting follow-up question involves regularization. Common regularization methods like L2-regularization penalize the large-norm solutions that minimize cross-entropy. Intuitively, convergence to the hard-margin solution is a consequence of the loss function becoming dominated by the support vectors as the norm of the weights diverges towards infinity (Soudry et al., 2017). If L2-regularization prevents such large-norm behavior, what, if anything, do the weights converge to in this case? Additionally, dropout (Srivastava et al., 2014) is often thought of as a regularization method for deep networks; however, it does not explicitly penalize large-norm solutions. We hypothesize that these regularization methods will have different impacts on the solution found by SGD on deep networks when combined with importance weighting.
In this paper, we empirically investigate these questions regarding the effects of importance weighting on deep networks through experiments with both linear and deep models on separable toy data, as well as on CIFAR-10 image classification. Our results show that (i) importance weighting seems not to have any meaningful effect on the final solutions found by deep networks trained with SGD; (ii) importance weighting does affect solutions found earlier in the training process, suggesting that previous results from importance-weighted deep networks likely depended on early stopping; and (iii) importance weighting significantly affects model decision boundaries when combined with L2-regularization, but this effect is absent when dropout is used instead.
2 Related Work
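Importance sampling allows one to estimate some properties of a target distribution $P$ when only a source distribution $Q$ is available to be sampled from (Horvitz and Thompson, 1952; Kahn and Marshall, 1953; Rubinstein and Kroese, 2016; Koller et al., 2009). Concretely, given samples $x_1, \dots, x_n \sim Q$ and the task of estimating the expectation of some function of the data, say $f(x)$, under the target distribution $P$, importance sampling produces an unbiased estimate (assuming $Q(x) > 0$ wherever $P(x) > 0$) by weighting each sample according to the likelihood ratio $P(x_i)/Q(x_i)$:

$$\mathbb{E}_{x \sim P}[f(x)] = \mathbb{E}_{x \sim Q}\left[\frac{P(x)}{Q(x)} f(x)\right] \approx \frac{1}{n} \sum_{i=1}^{n} \frac{P(x_i)}{Q(x_i)} f(x_i).$$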
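A minimal sketch of this estimator, using hypothetical discrete distributions on {0, 1, 2, 3} so that the exact target expectation is known: `p` plays the role of the target $P$, `q` the source $Q$ that we can sample from, and each sample's contribution is weighted by $P(x)/Q(x)$. The function names and numbers are illustrative only.

```python
import random


def importance_sampling_estimate(samples, f, p, q):
    """Estimate E_P[f(x)] from samples x ~ Q by weighting each term with P(x) / Q(x)."""
    return sum(p[x] / q[x] * f(x) for x in samples) / len(samples)


def f(x):
    return float(x)  # hypothetical function of interest


# Hypothetical discrete distributions on {0, 1, 2, 3}.
support = [0, 1, 2, 3]
p = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}      # target P (cannot be sampled)
q = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25}  # source Q (can be sampled)

exact = sum(p[x] * f(x) for x in support)  # E_P[f] = 2.0 up to rounding

rng = random.Random(0)
samples = rng.choices(support, weights=[q[x] for x in support], k=100_000)
estimate = importance_sampling_estimate(samples, f, p, q)
print(f"exact={exact:.4f}  IS estimate={estimate:.4f}")
```

Here the weighted term has per-sample variance 3.36, so with 100,000 draws the standard error of the estimate is roughly 0.006.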
Machine learning practitioners commonly exploit this idea in two ways: (i) by re-sampling to correct for the discrepancy in likelihood or (ii) by weighting examples according to the likelihood ratio (Rubinstein and Kroese, 2016; Shimodaira, 2000; Koller et al., 2009). The latter approach underlies the common method of importance-weighted empirical risk minimization (IW-ERM).
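For integer weights the two approaches coincide exactly: weighting an example by $c$ in a summed loss is the same as repeating it $c$ times. The check below assumes a 1-D logistic model with a bias term; the data, weights, parameter values and helper names are hypothetical.

```python
import math


def logistic_loss(w, b, x, y):
    """Cross-entropy (logistic) loss of a 1-D linear model; labels y are -1 or +1."""
    return math.log1p(math.exp(-y * (w * x + b)))


def weighted_empirical_risk(w, b, data, weights):
    """IW-ERM objective: weighted sum of per-example losses."""
    return sum(c * logistic_loss(w, b, x, y) for (x, y), c in zip(data, weights))


data = [(1.5, +1), (0.5, +1), (-1.0, -1)]
weights = [3, 1, 2]  # hypothetical integer importance weights

# Resampling view: repeat each example as many times as its weight.
resampled = [ex for ex, c in zip(data, weights) for _ in range(c)]
unweighted = weighted_empirical_risk(0.7, -0.1, resampled, [1] * len(resampled))
weighted = weighted_empirical_risk(0.7, -0.1, data, weights)
print(math.isclose(weighted, unweighted))  # True
```

With non-integer weights, resampling can only match the weighted objective in expectation.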
When the likelihood ratio is known, importance sampling offers unbiased estimates of the generalization error (Cortes et al., 2008), and IW-ERM optimizes an unbiased estimate of the loss on the target distribution $P$. However, this unbiasedness can come at the cost of significant variance: when $Q(x)$ is much smaller than $P(x)$ for some values of $x$, the importance weights $P(x)/Q(x)$ can grow arbitrarily large.
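The variance problem can be quantified. For a single draw $x \sim Q$, the weighted term $\frac{P(x)}{Q(x)} f(x)$ has variance $\sum_x P(x)^2 f(x)^2 / Q(x) - (\mathbb{E}_P[f])^2$, which grows without bound as $Q(x) \to 0$ at any point where $P(x) f(x) \neq 0$. The sketch below evaluates this formula exactly for hypothetical two-point distributions, as the source $Q$ puts less and less mass on a point the target $P$ weights heavily.

```python
def is_variance(p, q, f):
    """Per-sample variance of P(x)/Q(x) * f(x) under x ~ Q.

    Requires q[x] > 0 wherever p[x] > 0 (otherwise importance sampling is biased).
    """
    mean = sum(p[x] * f(x) for x in p)
    second_moment = sum(p[x] ** 2 * f(x) ** 2 / q[x] for x in p)
    return second_moment - mean ** 2


def f(x):
    return 1.0 + x  # hypothetical function of interest


p = {0: 0.5, 1: 0.5}  # hypothetical target P

for eps in [0.5, 0.1, 0.01, 0.001]:
    q = {0: 1.0 - eps, 1: eps}  # source Q puts only eps mass where P puts 0.5
    max_weight = max(p[x] / q[x] for x in p)
    print(f"eps={eps}: max weight={max_weight:.1f}, variance={is_variance(p, q, f):.2f}")
```

For these values the largest weight grows from 1 to 500 and the per-sample variance from 0.25 to roughly 1000.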
Algorithms employing IW-ERM appear in many areas of machine learning. Typically, in causal inference we possess only offline data but wish to answer counterfactual questions, estimating the statistical properties that would hold were we to intervene and thereby change the data-generating distribution (Rosenbaum and Rubin, 1983). More recently, Johansson et al. (2018) employ weighted risk minimization in a study aimed at estimating causal effects using deep neural networks.
In off-policy reinforcement learning, we wish to evaluate a target policy $\pi$ using only trajectories sampled from a different source (behavior) policy $\mu$. In the bandit literature, where importance sampling is sometimes known as inverse propensity scoring (Horvitz and Thompson, 1952; Dudík et al., 2011; Agarwal et al., 2014), the technique is especially important because it can be prohibitively expensive to collect a fresh set of data (by interacting with the environment) every time we wish to evaluate a proposed policy. Importance sampling is also found in the literature on temporal difference learning in Markov decision processes (Precup, 2000; Precup et al., 2001). More recently, various applications of importance sampling have emerged in the modern literature on learning contextual bandits with deep networks (Swaminathan and Joachims, 2015), deep reinforcement learning (Pinto and Gupta, 2016), and deep imitation learning (Murali et al., 2016). Another deep reinforcement learning paper uses weighted sampling to choose experiences from the replay buffer for performing TD updates (Schaul et al., 2016). However, these weights are not based on likelihood ratios; instead, they are heuristically chosen to be proportional to the Bellman errors for those experiences.
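A minimal sketch of inverse propensity scoring for a non-contextual, three-armed bandit (pure Python; rewards, policies, sample size and helper names are hypothetical). Each logged record stores the action chosen by the logging policy $\mu$, the observed reward, and the propensity $\mu(a)$; the value of the target policy $\pi$ is estimated by averaging $\frac{\pi(a)}{\mu(a)} r$ over the log, without ever running $\pi$.

```python
import random


def ips_estimate(logged, target_policy):
    """Average of pi(a) / mu(a) * r over logged (action, reward, propensity) tuples."""
    return sum(target_policy[a] / prop * r for a, r, prop in logged) / len(logged)


# Hypothetical 3-armed bandit with deterministic rewards.
rewards = [1.0, 0.0, 0.5]
logging_policy = [1 / 3, 1 / 3, 1 / 3]  # source policy mu that collected the data
target_policy = [0.8, 0.1, 0.1]         # policy pi we want to evaluate offline

true_value = sum(pi_a * r for pi_a, r in zip(target_policy, rewards))  # 0.85 up to rounding

rng = random.Random(0)
logged = []
for _ in range(50_000):
    a = rng.choices(range(3), weights=logging_policy)[0]
    logged.append((a, rewards[a], logging_policy[a]))

print(f"true={true_value:.3f}  IPS estimate={ips_estimate(logged, target_policy):.3f}")
```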
Weighted ERM is also used in active learning, where one hopes to correct for deliberate sample selection bias (Beygelzimer et al., 2009), and in domain adaptation, to learn a classifier on data sampled from the source distribution $Q$ that minimizes error on a target distribution $P$. While domain adaptation is impossible in general, under the covariate shift (sample selection bias) assumption (Shimodaira, 2000) or label shift (also called prior probability shift) (Elkan, 2001), models can be improved by incorporating unlabeled target data, often through IW-ERM. Recently, Lipton et al. (2018) introduced a consistent estimator for the test set label distribution under mild conditions, demonstrating that deep neural networks can be improved through retraining with accurately estimated importance weights.
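Under label shift the importance weights $w(y) = P(y)/Q(y)$ depend only on the label. As we understand it, the black-box shift estimation approach of Lipton et al. (2018) recovers them by solving a linear system $\hat{C} w = \hat{\mu}$, where $\hat{C}$ is the joint confusion matrix of a fixed classifier on held-out source data and $\hat{\mu}$ is the distribution of its predictions on unlabeled target data. The two-class sketch below uses hypothetical numbers constructed so that the source labels are balanced and the target labels are split 20%/80%, which makes the true weights 0.4 and 1.6; the helper name is ours.

```python
def solve_2x2(a, rhs):
    """Solve the 2x2 linear system a @ x = rhs with Cramer's rule."""
    (a00, a01), (a10, a11) = a
    det = a00 * a11 - a01 * a10
    if abs(det) < 1e-12:
        raise ValueError("confusion matrix is (nearly) singular")
    return [(rhs[0] * a11 - a01 * rhs[1]) / det,
            (a00 * rhs[1] - rhs[0] * a10) / det]


# C[i][j] = estimated P_Q(prediction = i, label = j) on held-out labeled source data.
C = [[0.45, 0.10],
     [0.05, 0.40]]
# mu[i] = fraction of unlabeled target examples the classifier assigns to class i.
mu = [0.34, 0.66]

weights = solve_2x2(C, mu)  # estimated w(y) = P(y) / Q(y)
print(weights)  # approximately [0.4, 1.6]
```

With finite samples, $\hat{C}$ and $\hat{\mu}$ are noisy, so the recovered weights are estimates rather than exact values.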
Recently, Soudry et al. (2017) investigated explanations for the generalization abilities of neural networks. Problematically, given a fixed neural network architecture, there exist infinitely many sets of parameters that achieve empirical risk arbitrarily close to 0, most of which generalize poorly to unseen data. Thus, determining precisely why neural networks generalize requires understanding the behavior of the optimization algorithm, typically a variant of stochastic gradient descent (SGD). For linear least-squares problems, SGD initialized at the origin is known to converge to the minimum Euclidean norm solution. However, when optimizing cross-entropy loss on a classification task, there is no finite optimal solution when the data is separable. This is easy to see in the linear case (logistic regression) with labels $y_i \in \{-1, +1\}$: if some $w$ satisfies $y_i w^\top x_i > 0$ for all $i$, then $\mathcal{L}(c\,w) = -\sum_i \log \sigma(c\, y_i w^\top x_i)$ decreases monotonically toward 0 as $c \to \infty$, so the loss can always be driven further down by scaling up the weights, and $\|w_t\| \to \infty$ as $t \to \infty$. Here $\mathcal{L}$ is the loss, $\sigma$ is the sigmoid activation function, and $w_t$ is the value of the weights at iteration $t$ of SGD. If only the direction of the weights, but not their magnitude, determines the partitioning of the data into classes, then what does the direction $w_t / \|w_t\|$ converge to? Soudry et al. (2017) show that for linear models it converges to the solution of the hard-margin support vector machine (SVM). These results hold for linear networks and were subsequently extended to deep fully-connected linear networks by Gunasekar et al. (2018), who also analyze linear convolutional networks, showing that for these networks the solution is related to a bridge penalty in the frequency domain.
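The scaling argument can be checked numerically. In this minimal sketch (hypothetical separable 2-D data, no bias term, helper name ours), the logistic loss $-\sum_i \log \sigma(y_i w^\top x_i)$ is evaluated at $c\,w$ for a fixed separating $w$ and increasing $c$; the loss keeps shrinking toward 0, so no finite minimizer exists.

```python
import math


def logistic_loss_total(w, data):
    """Sum of -log(sigmoid(y * <w, x>)) over the data; labels y are -1 or +1."""
    total = 0.0
    for x, y in data:
        margin = y * sum(wi * xi for wi, xi in zip(w, x))
        total += math.log1p(math.exp(-margin))  # equals -log(sigmoid(margin))
    return total


# Hypothetical linearly separable 2-D data; w_sep classifies every point correctly.
data = [((2.0, 1.0), +1), ((1.0, 2.0), +1), ((-1.0, -2.0), -1), ((-2.0, -0.5), -1)]
w_sep = (1.0, 1.0)
assert all(y * (w_sep[0] * x[0] + w_sep[1] * x[1]) > 0 for x, y in data)

losses = []
for c in [1, 2, 4, 8, 16]:
    scaled = tuple(c * wi for wi in w_sep)
    losses.append(logistic_loss_total(scaled, data))
    print(c, losses[-1])
```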
We investigate the effects of importance weighting on neural networks using both a linearly separable 2-dimensional toy dataset and the CIFAR-10 image dataset. Our experiments address the label-shift scenario, weighting examples based on their class. Specifically, we down-weight the loss contributions of examples from a particular class. We also test the combination of regularization and importance weighting on the toy dataset. For L2-regularization, we set the penalty coefficient as , and when using dropout on deep networks, we set the value of a hidden unit to 0 during training with probability .
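As a concrete reading of the class-based weighting and the two regularizers described above, here is a minimal pure-Python sketch of a class-weighted binary cross-entropy objective with an optional L2 penalty, plus a dropout mask. The helper names, the example logits, and the down-weighting factor of 1/8 are hypothetical. Rescaling surviving units by $1/(1-p)$ ("inverted dropout") is an assumption, since the text only says that units are zeroed.

```python
import math
import random


def class_weighted_bce(logits, labels, class_weights, params=(), l2=0.0):
    """Mean class-weighted binary cross-entropy plus an optional L2 penalty.

    labels are 0 or 1; class_weights[k] scales the loss of every example whose label is k.
    """
    total = 0.0
    for z, y in zip(logits, labels):
        s = z if y == 1 else -z  # loss is -log(sigmoid(s)) = log(1 + exp(-s))
        total += class_weights[y] * (max(-s, 0.0) + math.log1p(math.exp(-abs(s))))
    penalty = 0.5 * l2 * sum(w * w for w in params)
    return total / len(logits) + penalty


def dropout(activations, p, rng, training=True):
    """Zero each hidden unit with probability p during training; identity otherwise.

    Survivors are scaled by 1 / (1 - p) ("inverted dropout"), which is an assumption here.
    """
    if not training or p == 0.0:
        return list(activations)
    return [0.0 if rng.random() < p else a / (1.0 - p) for a in activations]


logits = [2.0, -1.0, 0.5, -3.0]
labels = [1, 0, 1, 0]
down_weighted = {0: 1.0, 1: 1.0 / 8}  # hypothetical: down-weight class 1 by a factor of 8

print(class_weighted_bce(logits, labels, down_weighted))
print(class_weighted_bce(logits, labels, down_weighted, params=[0.5, -1.0], l2=0.01))
print(dropout([1.0, 2.0, 3.0, 4.0], p=0.5, rng=random.Random(0)))
```

Note that the L2 term adds to the objective itself, while dropout only perturbs activations and adds no explicit norm term.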
In order to visualize model decision boundaries, we conduct an experiment with a synthetic 2-dimensional linearly separable dataset. examples were sampled from a 2-D truncated normal distribution to form a positive class. The set of negative examples was generated by rotating and translating the positive set, as shown in Figure 1. We train both a logistic regression model (without regularization) and a multi-layer perceptron (MLP) using minibatch SGD for 10,000 epochs with a batch size of . The MLP has a single hidden layer of hidden units with ReLU activations. Both models use a fixed learning rate of , where is the maximum singular value of the data matrix. This learning rate was chosen to match the experiments of Soudry et al. (2017), and took a value of on our dataset. Results are shown in Figures 1, 2, and 3.
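The construction of the toy data can be sketched as follows. The sample count, truncation region, rotation angle and translation used in the experiments are not recoverable from the text, so every numeric value here is hypothetical; truncating to a square box is our assumption, and the translation is chosen large enough to guarantee linear separability. The helper `max_singular_value` computes the maximum singular value of the data matrix, the quantity the learning rate above is based on.

```python
import math
import random


def truncated_normal_2d(n, std, bound, rng):
    """Rejection-sample n points from an isotropic 2-D normal truncated to |coord| <= bound."""
    points = []
    while len(points) < n:
        x, y = rng.gauss(0.0, std), rng.gauss(0.0, std)
        if abs(x) <= bound and abs(y) <= bound:
            points.append((x, y))
    return points


def rotate_and_translate(points, angle, shift):
    """Rotate each point by `angle` radians about the origin, then add `shift`."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y + shift[0], s * x + c * y + shift[1]) for x, y in points]


def max_singular_value(rows):
    """Largest singular value of an n x 2 matrix: sqrt of the top eigenvalue of X^T X."""
    a = sum(x * x for x, _ in rows)
    b = sum(x * y for x, y in rows)
    d = sum(y * y for _, y in rows)
    top_eig = (a + d) / 2 + math.sqrt(((a - d) / 2) ** 2 + b * b)
    return math.sqrt(top_eig)


rng = random.Random(0)
positives = truncated_normal_2d(n=100, std=1.0, bound=1.0, rng=rng)  # hypothetical values
negatives = rotate_and_translate(positives, angle=math.pi / 4, shift=(4.0, 0.0))
X = positives + negatives
sigma_max = max_singular_value(X)
print(f"sigma_max = {sigma_max:.3f}")
```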
We also conduct experiments on the CIFAR-10 dataset (see results in Figure 4). Here, we train a binary classifier on training images labeled as cats or dogs ( per class), evaluating on all test images from all classes as well as random noise images. The classifier is a convolutional network with the following structure: two convolution layers with filters each and stride , followed by a max pooling layer, followed by three convolution layers with filters each and stride , followed by a second max pooling layer, followed by two dense layers with and hidden units respectively, then the binary output layer. All hidden layer activation functions are ReLUs. The model is trained for epochs using minibatch SGD with a batch size of , constant learning rate of , and no momentum. Experiments were run with importance weights of inverse powers of 2 for each class, as well as with no importance weighting. All successfully trained models achieved test accuracies between and .
Our results show that as training progresses, the effects due to importance weighting vanish (Figures 1 and 4). While weighting impacts a model’s decision boundary early in training, in the limit, models with widely-varying weights appear to approach identical solutions. After many epochs of training, there is no clear relationship between the class-based importance weights and the classification ratios on either test set images, out-of-domain images, or random vectors (Figure 3(d)). However, models with more extreme weighting converge more slowly (to the same decision boundary). We note that these findings only hold up to class ratios of 256:1; at 512:1 models destabilize, classifying all examples as the up-weighted class (Figure 4).
We also find that importance weighting does have an effect on L2-regularized models (Figure 2): both logistic regression and the neural network partition more of the sample space to the class that is not down-weighted. However, this effect is absent when L2-regularization is replaced with dropout (Figure 3); in this case, the models behave as they do without regularization.
Our experiments suggest that effects from importance weighting on deep networks may only occur in conjunction with early stopping, disappearing asymptotically. For these vastly under-determined models capable of fitting any training set, the final solution may be determined solely by the location of training examples, independent of their density. This aligns with the theoretical results from Soudry et al. (2017), and causes us to question the role of importance weighting for deep networks.
Importance weighting does seem to affect models when combined with L2-regularization, however. We believe that in this case the L2 penalty prevents SGD from approaching the large-norm solutions whose loss is dominated by the support vectors, thus preventing convergence to the hard-margin solution. The importance weights can then meaningfully affect the loss function. This explanation is consistent with our finding that dropout, which does not explicitly penalize the weight norm, does not similarly affect the decision boundary in the limit.
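The proposed mechanism can be illustrated with a minimal sketch (pure Python). The 1-D data, the penalty of 0.1, the step size, the iteration count and the helper names are all hypothetical and unrelated to the actual experimental settings. It runs full-batch gradient descent on a class-weighted, L2-penalized logistic loss with an unpenalized bias. Because the penalty keeps the weights bounded, the objective has a finite minimizer, and that minimizer depends on the class weights.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_weighted_l2_logreg(data, class_weights, l2, lr=0.1, steps=10_000):
    """Full-batch gradient descent on a 1-D logistic model f(x) = w * x + b.

    Objective: mean_i c[y_i] * log(1 + exp(-y_i * f(x_i))) + (l2 / 2) * w**2,
    with labels y in {-1, +1}; the bias b is not penalized.
    """
    w, b = 0.0, 0.0
    n = len(data)
    for _ in range(steps):
        gw, gb = l2 * w, 0.0
        for x, y in data:
            # d/df of log(1 + exp(-y * f)) is -y * sigmoid(-y * f)
            g = -class_weights[y] * y * sigmoid(-y * (w * x + b)) / n
            gw += g * x
            gb += g
        w -= lr * gw
        b -= lr * gb
    return w, b


# Hypothetical symmetric, separable data.
data = [(x, +1) for x in (1.0, 2.0, 3.0)] + [(-x, -1) for x in (1.0, 2.0, 3.0)]

for weights in ({+1: 1.0, -1: 1.0}, {+1: 0.1, -1: 1.0}):
    w, b = fit_weighted_l2_logreg(data, weights, l2=0.1)
    print(weights, "threshold:", -b / w)
```

With equal weights the symmetric data give a threshold of (numerically) 0; down-weighting the positive class by 0.1 moves the threshold toward the positive points, giving more of the line to the class that is not down-weighted. With `l2=0` the same data are separable, so the weight norm would keep growing instead of settling.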
As previously noted, importance weighting has been empirically shown to be useful for deep networks by several others (Lipton et al., 2018; Schaul et al., 2016; Burda et al., 2015). If importance weighting is only useful for deep networks in combination with early stopping, is there a principled way to choose stopping times when importance weighting is desired? In future work, we will investigate whether these same patterns emerge with learning rate schedules or adaptive learning rates (such as AdaGrad and RMSprop), and compare loss-function weighting to over- and under-sampling, continuing to use empirical inquiry to explore the ramifications of developments in the theory of deep learning.
- Alekh Agarwal, Daniel Hsu, Satyen Kale, John Langford, Lihong Li, and Robert Schapire. Taming the monster: A fast and simple algorithm for contextual bandits. In International Conference on Machine Learning (ICML), 2014.
- Alina Beygelzimer, Sanjoy Dasgupta, and John Langford. Importance weighted active learning. In International Conference on Machine Learning (ICML), 2009.
- Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. In International Conference on Learning Representations (ICLR), 2015.
- Corinna Cortes, Mehryar Mohri, Michael Riley, and Afshin Rostamizadeh. Sample selection bias correction theory. In International Conference on Algorithmic Learning Theory (ALT), 2008.
- Miroslav Dudík, John Langford, and Lihong Li. Doubly robust policy evaluation and learning. In International Conference on Machine Learning (ICML), 2011.
- Charles Elkan. The foundations of cost-sensitive learning. In International Joint Conference on Artificial Intelligence (IJCAI), 2001.
- Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Implicit bias of gradient descent on linear convolutional networks. arXiv preprint arXiv:1806.00468, 2018.
- Daniel G Horvitz and Donovan J Thompson. A generalization of sampling without replacement from a finite universe. Journal of the American Statistical Association (JASA), 1952.
- Fredrik D Johansson, Nathan Kallus, Uri Shalit, and David Sontag. Learning weighted representations for generalization across designs. arXiv preprint arXiv:1802.08598, 2018.
- Herman Kahn and Andy W Marshall. Methods of reducing sample size in Monte Carlo computations. Journal of the Operations Research Society of America, 1953.
- Daphne Koller, Nir Friedman, and Francis Bach. Probabilistic graphical models: principles and techniques. MIT Press, 2009.
- Zachary C Lipton, Yu-Xiang Wang, and Alex Smola. Detecting and correcting for label shift with black box predictors. In International Conference on Machine Learning (ICML), 2018.
- Adithyavairavan Murali, Animesh Garg, Sanjay Krishnan, Florian T Pokorny, Pieter Abbeel, Trevor Darrell, and Ken Goldberg. TSC-DL: Unsupervised trajectory segmentation of multi-modal surgical demonstrations with deep learning. In Robotics and Automation (ICRA), 2016.
- Lerrel Pinto and Abhinav Gupta. Supersizing self-supervision: Learning to grasp from 50k tries and 700 robot hours. In Robotics and Automation (ICRA), 2016.
- Doina Precup. Eligibility traces for off-policy policy evaluation. Computer Science Department Faculty Publication Series, 2000.
- Doina Precup, Richard S Sutton, and Sanjoy Dasgupta. Off-policy temporal-difference learning with function approximation. In International Conference on Machine Learning (ICML), 2001.
- Paul R Rosenbaum and Donald B Rubin. The central role of the propensity score in observational studies for causal effects. Biometrika, 1983.
- Reuven Y Rubinstein and Dirk P Kroese. Simulation and the Monte Carlo method. John Wiley & Sons, 2016.
- Tom Schaul, John Quan, Ioannis Antonoglou, and David Silver. Prioritized experience replay. In International Conference on Learning Representations (ICLR), 2016.
- Hidetoshi Shimodaira. Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of Statistical Planning and Inference, 2000.
- Daniel Soudry, Elad Hoffer, and Nathan Srebro. The implicit bias of gradient descent on separable data. In International Conference on Learning Representations (ICLR), 2017.
- Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 2014.
- Adith Swaminathan and Thorsten Joachims. Counterfactual risk minimization: Learning from logged bandit feedback. In International Conference on Machine Learning (ICML), 2015.
- Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. In International Conference on Learning Representations (ICLR), 2017.