Why resampling outperforms reweighting for correcting sampling bias
A data set sampled from a certain population is biased if the subgroups of the population are sampled at proportions that are significantly different from their underlying proportions. Training machine learning models on biased data sets requires correction techniques to compensate for potential biases. We consider two commonly used techniques, resampling and reweighting, that rebalance the proportions of the subgroups to maintain the desired objective function. Although the two are statistically equivalent, it has been observed that resampling outperforms reweighting when combined with stochastic gradient algorithms. By analyzing illustrative examples, we explain the reason behind this phenomenon using tools from dynamical stability and stochastic asymptotics. We also present experiments from regression, classification, and off-policy prediction to demonstrate that this is a general phenomenon. We argue that it is imperative to consider the objective function design and the optimization algorithm together while addressing the sampling bias.
A data set sampled from a certain population is called biased if the subgroups of the population are sampled at proportions that are significantly different from their underlying population proportions. Applying machine learning algorithms naively to biased training data can raise serious concerns and lead to controversial results (Sweeney, 2013; Kay et al., 2015; Menon et al., 2020). In many domains, such as demographic surveys, fraud detection, identification of rare diseases, and natural disaster prediction, a model trained on biased data tends to favor oversampled subgroups, achieving high accuracy on them while sacrificing performance on undersampled subgroups. Although the situation can be improved by diversifying and balancing the data during collection, it is often hard or impossible to eliminate the sampling bias due to historical and operational issues.
In order to mitigate the biases and discrimination against the undersampled subgroups, a common technique is to preprocess the data set by compensating for the mismatch between the population proportion and the sampling proportion. Among various approaches, two commonly used choices are reweighting and resampling. In reweighting, one multiplies the loss of each sample by a ratio equal to its population proportion over its sampling proportion. In resampling, on the other hand, one corrects the proportion mismatch by either generating new samples for the undersampled subgroups or selecting a subset of samples for the oversampled subgroups. Both methods result in statistically equivalent models in terms of the loss function (see details in Section 2). However, it has been observed in practice that resampling often outperforms reweighting significantly, for example in boosting algorithms for classification (Galar et al., 2011; Seiffert et al., 2008) and in off-policy prediction in reinforcement learning (Schlegel et al., 2019). The obvious question is why.
Our main contribution is to provide an answer to this question: resampling outperforms reweighting because of the stochastic gradient-type algorithms used for training. To the best of our knowledge, our explanation is the first quantitative theoretical analysis of this phenomenon. Since stochastic gradient descent (SGD) is the dominant method for model training, our analysis builds on recent developments in understanding SGD. We show via simple and explicitly analyzable examples why resampling produces the expected results while reweighting performs undesirably. Our theoretical analysis takes two points of view, one from dynamical stability and the other from stochastic asymptotics.
In addition to the theoretical analysis, we present experimental examples from three distinct categories (classification, regression, and off-policy prediction) to demonstrate that resampling outperforms reweighting in practice. This empirical study illustrates that this is a fairly general phenomenon when models are trained using stochastic gradient-type algorithms.
Our theoretical analysis and experiments show clearly that adjusting only the loss function is not sufficient for fixing the biased data problem. The outcome can be disastrous if one overlooks the optimization algorithm used in training. In fact, recent work has shown that objective function design and optimization algorithm are closely related; for example, optimization algorithms such as SGD play a key role in the generalization of deep neural networks. Therefore, in order to address the biased data issue, we advocate for considering data, model, and optimization as an integrated system.
In a broader scope, resampling and reweighting can be considered as instances of preprocessing the training data to tackle biases of machine learning algorithms. Though there are many well-developed resampling (Mani and Zhang, 2003; He and Garcia, 2009; Maciejewski and Stefanowski, 2011) and reweighting (Kumar et al., 2010; Malisiewicz et al., 2011; Chang et al., 2017) techniques, we focus only on the basic resampling and reweighting approaches that do not change the optimization problem. It is well known that training algorithms on disparate data can lead to algorithmic discrimination (Bolukbasi et al., 2016; Caliskan et al., 2017), and over the years there have been growing efforts to mitigate such biases; see, for example, (Amini et al., 2019; Kamiran and Calders, 2012; Calmon et al., 2017; Zhao et al., 2019; López et al., 2013). We also refer to (Haixiang et al., 2017; He and Ma, 2013; Krawczyk, 2016) for a comprehensive review of this growing research field.
Our approaches for understanding the dynamics of resampling and reweighting under SGD are based on tools from numerical analysis for stochastic systems. Connections between numerical analysis and stochastic algorithms have developed rapidly in recent years. The dynamical stability perspective has been used in (Wu et al., 2018) to show the impact of learning rate and batch size on minima selection. The stochastic differential equation (SDE) approach for approximating stochastic optimization methods can be found in a line of work (Li et al., 2017, 2019; Rotskoff and Vanden-Eijnden, 2018; Shi et al., 2019), to mention just a few.
2 Problem setup
Let us consider a population that is comprised of two different groups, where a proportion $a_1$ of the population belongs to the first group, and the rest, with proportion $a_2$, belongs to the second (i.e., $a_1, a_2 > 0$ and $a_1 + a_2 = 1$). In what follows, we call $a_1$ and $a_2$ the population proportions. Consider an optimization problem for this population over a parameter $\theta$. For simplicity, we assume that each individual from the first group experiences a loss function $V_1(\theta)$, while each individual from the second group has a loss function $V_2(\theta)$. Here the loss function $V_1$ is assumed to be identical across all members of the first group, and likewise $V_2$ across the second group; however, it is possible to extend the formulation to allow the loss function to vary within each group. Based on this setup, a minimization problem over the whole population is to find
$$\theta^* = \arg\min_\theta V(\theta), \qquad V(\theta) := a_1 V_1(\theta) + a_2 V_2(\theta). \tag{1}$$
For a given set of $N$ individuals sampled uniformly from the population, the empirical minimization problem is
$$\theta^* = \arg\min_\theta \frac{1}{N}\sum_{r=1}^{N} V_{i_r}(\theta), \tag{2}$$
where $i_r \in \{1, 2\}$ denotes which group individual $r$ belongs to. As $N$ grows, the empirical loss in (2) is consistent with the population loss in (1), since approximately an $a_1$ fraction of the samples comes from the first group and an $a_2$ fraction from the second.
However, in reality the sampling can be far from uniformly random. Let $N_1$ and $N_2$ with $N_1 + N_2 = N$ denote the number of samples from the first and the second group, respectively. It is convenient to define $f_1$ and $f_2$ as the sampling proportions for each group, i.e., $f_1 = N_1/N$ and $f_2 = N_2/N$ with $f_1 + f_2 = 1$. The data set is biased when the sampling proportions $f_1$ and $f_2$ are different from the population proportions $a_1$ and $a_2$. In such a case, the empirical loss is $f_1 V_1(\theta) + f_2 V_2(\theta)$, which is clearly wrong when compared with (1).
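Let us consider two basic strategies to adjust the model: reweighting and resampling. In reweighting, one assigns to each sample $r$ the weight $a_{i_r}/f_{i_r}$, and the reweighting loss function is
$$V^w(\theta) := \frac{1}{N}\sum_{r=1}^{N} \frac{a_{i_r}}{f_{i_r}}\, V_{i_r}(\theta) = a_1 V_1(\theta) + a_2 V_2(\theta). \tag{3}$$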
In resampling, one either adds samples to the minority group (i.e., oversampling) or removes samples from the majority group (i.e., undersampling). Although the actual implementation of oversampling and undersampling can be quite sophisticated in order to avoid overfitting or loss of information, mathematically we interpret resampling as constructing a new set of samples of size $M$, among which $a_1 M$ samples are from the first group and $a_2 M$ samples from the second. Writing $i^s_m$ for the group of the $m$-th resampled individual, the resampling loss function is
$$V^s(\theta) := \frac{1}{M}\sum_{m=1}^{M} V_{i^s_m}(\theta) = a_1 V_1(\theta) + a_2 V_2(\theta). \tag{4}$$
Notice that both $V^s$ and $V^w$ are consistent with the population loss function $V$. This means that, under mild conditions on $V_1$ and $V_2$, a deterministic gradient descent algorithm from a generic initial condition converges to similar solutions for $V^s$ and $V^w$. However, as we explain below, the behavior can be drastically different when a stochastic gradient algorithm is used. The key reason is that the variances of the stochastic gradients can be very different.
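As a quick numerical sanity check of the consistency claims above, the following sketch evaluates the naive, reweighted, and resampled empirical losses on a biased sample. The per-group losses $V_1(\theta) = (\theta-1)^2$ and $V_2(\theta) = (\theta+2)^2$, the proportions, and all function names are hypothetical choices made only for illustration, and resampling is implemented here as drawing with replacement, which is just one of many possible schemes.

```python
import random

# Hypothetical per-group losses V_1 and V_2 (stored at indices 0 and 1), for illustration only.
LOSSES = (lambda th: (th - 1.0) ** 2, lambda th: (th + 2.0) ** 2)


def population_loss(theta, a):
    """Population loss (1): a_1 V_1(theta) + a_2 V_2(theta)."""
    return sum(a[g] * LOSSES[g](theta) for g in range(2))


def naive_loss(theta, labels):
    """Uncorrected empirical loss: plain average of V_{i_r}(theta) over the sample."""
    return sum(LOSSES[g](theta) for g in labels) / len(labels)


def reweighted_loss(theta, labels, a):
    """Reweighting: each sample from group g carries the weight a_g / f_g."""
    n = len(labels)
    f = [labels.count(g) / n for g in range(2)]  # sampling proportions f_1, f_2
    return sum(a[g] / f[g] * LOSSES[g](theta) for g in labels) / n


def resampled_labels(labels, a, m, rng):
    """Resampling: draw round(a_1 * m) group-1 and the remaining group-2 samples, with replacement."""
    pools = [[k for k, g in enumerate(labels) if g == grp] for grp in range(2)]
    m1 = round(a[0] * m)
    picks = rng.choices(pools[0], k=m1) + rng.choices(pools[1], k=m - m1)
    return [labels[k] for k in picks]


if __name__ == "__main__":
    rng = random.Random(0)
    a = (0.7, 0.3)                    # population proportions
    labels = [0] * 100 + [1] * 900    # biased sample: f = (0.1, 0.9)
    theta = 0.5
    print("population :", population_loss(theta, a))
    print("naive      :", naive_loss(theta, labels))
    print("reweighted :", reweighted_loss(theta, labels, a))
    print("resampled  :", naive_loss(theta, resampled_labels(labels, a, 1000, rng)))
```

With these numbers the population, reweighted, and resampled losses agree up to floating-point rounding (about 2.05), whereas the naive loss (about 5.65) does not.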
3 Stability analysis
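Let us use a simple example to illustrate why resampling outperforms reweighting under SGD, from the viewpoint of stability. Consider two loss functions $V_1$ and $V_2$ with disjoint supports,
$$V_1(\theta) = \begin{cases} (\theta+1)^2 - 1, & \theta \le 0,\\ 0, & \theta > 0,\end{cases} \qquad V_2(\theta) = \begin{cases} 0, & \theta \le 0,\\ (\theta-1)^2 - 1, & \theta > 0,\end{cases}$$
each of which is quadratic on its support. The population loss function is $V(\theta) = a_1 V_1(\theta) + a_2 V_2(\theta)$, with two local minima at $\theta = -1$ and $\theta = 1$. The gradients of the reweighting loss $V^w$ and the resampling loss $V^s$ are
$$\nabla V^w(\theta) = \frac{1}{N}\sum_{r=1}^{N} \frac{a_{i_r}}{f_{i_r}}\,\nabla V_{i_r}(\theta), \qquad \nabla V^s(\theta) = \frac{1}{M}\sum_{m=1}^{M} \nabla V_{i^s_m}(\theta).$$
A single-sample stochastic gradient picks one term of each sum uniformly at random, so for resampling it equals $\nabla V_i$ with probability $a_i$, while for reweighting it equals $(a_i/f_i)\nabla V_i$ with probability $f_i$.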
Suppose that the population proportions satisfy $a_1 > a_2$; then $\theta = -1$ is the global minimum, and it is desirable that SGD be stable near it. However, as shown in Figure 1, when the sampling proportion $f_1$ is significantly less than the population proportion $a_1$, SGD for reweighting can easily become unstable: even if one starts near the global minimum $\theta = -1$, the trajectories for reweighting always drift toward $\theta = 1$ after a few steps (see Figure 1(1)). On the other hand, SGD for resampling is quite stable (see Figure 1(2)).
Figure 1: SGD trajectories under (1) reweighting and (2) resampling.
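The behavior in Figure 1 can be reproduced qualitatively with a few lines of single-sample SGD. The sketch below assumes the quadratic pieces $V_1(\theta) = (\theta+1)^2 - 1$ for $\theta \le 0$ (zero otherwise) and $V_2(\theta) = (\theta-1)^2 - 1$ for $\theta > 0$ (zero otherwise); the values $a_1 = 0.6$, $f_1 = 0.05$, $\eta = 0.3$ and the function names are illustrative choices, not the settings used for the figure.

```python
import random


def grad_v1(theta):
    """Gradient of V_1: 2(theta + 1) on theta <= 0, zero on theta > 0."""
    return 2.0 * (theta + 1.0) if theta <= 0 else 0.0


def grad_v2(theta):
    """Gradient of V_2: 2(theta - 1) on theta > 0, zero on theta <= 0."""
    return 2.0 * (theta - 1.0) if theta > 0 else 0.0


def run_sgd(theta0, eta, steps, a1, f1, method, seed=0):
    """Single-sample SGD on the corrected loss.

    resampling:  group 1 is drawn with probability a1, unit weight.
    reweighting: group 1 is drawn with probability f1, weight a_i / f_i.
    """
    rng = random.Random(seed)
    a2, f2 = 1.0 - a1, 1.0 - f1
    theta = theta0
    for _ in range(steps):
        if method == "resampling":
            g = grad_v1(theta) if rng.random() < a1 else grad_v2(theta)
        elif method == "reweighting":
            if rng.random() < f1:
                g = (a1 / f1) * grad_v1(theta)
            else:
                g = (a2 / f2) * grad_v2(theta)
        else:
            raise ValueError(f"unknown method: {method}")
        theta -= eta * g
    return theta


if __name__ == "__main__":
    for method in ("resampling", "reweighting"):
        print(method, run_sgd(-0.9, 0.3, 2000, a1=0.6, f1=0.05, method=method, seed=1))
```

With these values the resampling run started at $\theta = -0.9$ contracts onto $\theta = -1$, while each rare group-1 draw under reweighting multiplies $\theta + 1$ by $1 - 2\eta a_1/f_1 = -6.2$, so two such draws throw the iterate into $\theta > 0$, where it then settles at $\theta = 1$.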
The expectations of the stochastic gradient are the same for both methods; it is the difference in the second moment that explains why trajectories near the two minima behave differently. Our explanation is based on the stability analysis framework of (Wu et al., 2018). By definition, a stationary point $\theta^*$ is stochastically stable if there exists a uniform constant $C > 0$ such that $\mathbb{E}\,|\theta_t - \theta^*|^2 \le C\,|\theta_0 - \theta^*|^2$ for all $t$, where $\theta_t$ is the $t$-th iterate of SGD. The stability conditions for resampling and reweighting are stated in the following two lemmas.
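Lemma 1. For resampling, the conditions for SGD to be stochastically stable around $\theta = -1$ and $\theta = 1$ are, respectively,
$$(1 - 2\eta a_1)^2 + 4\eta^2 a_1 a_2 \le 1, \qquad (1 - 2\eta a_2)^2 + 4\eta^2 a_1 a_2 \le 1.$$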
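Lemma 2. For reweighting, the conditions for SGD to be stochastically stable around $\theta = -1$ and $\theta = 1$ are, respectively,
$$(1 - 2\eta a_1)^2 + 4\eta^2 a_1^2\,\frac{f_2}{f_1} \le 1, \qquad (1 - 2\eta a_2)^2 + 4\eta^2 a_2^2\,\frac{f_1}{f_2} \le 1.$$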
Note that the stability conditions for resampling are independent of the sampling proportions $(f_1, f_2)$, while those for reweighting clearly depend on them. The detailed computations are given in Appendix A.
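Lemma 2 shows that reweighting makes problems stiffer in terms of the stability criterion. The two conditions of Lemma 2 simplify to $\eta \le f_1/a_1$ at $\theta = -1$ and $\eta \le f_2/a_2$ at $\theta = 1$, whereas both conditions of Lemma 1 reduce to $\eta \le 1$. Let us consider the case $f_1 = \epsilon$ with a small constant $0 < \epsilon \ll 1$, while $a_1 > a_2$. For reweighting, the global minimum $\theta = -1$ is stochastically stable only if $\eta \le \epsilon/a_1$. This condition is rather stringent in terms of the learning rate since $\epsilon \ll 1$. On the other hand, the local minimizer $\theta = 1$ is stable if $\eta \le (1-\epsilon)/a_2$, which is satisfied for a much broader range of $\eta$ because $f_2 = 1 - \epsilon \approx 1$. In other words, for a fixed learning rate $\eta$, when the ratio $f_1/f_2$ between the sampling proportions is sufficiently small, the desired minimizer is no longer stochastically stable with respect to SGD.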
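The closed-form conditions of Lemmas 1 and 2 can be checked numerically, and the critical learning rates can be read off by bisection. The sketch assumes the quadratic losses of this section, so that near a minimizer the per-step multiplier of the second moment is $\mathbb{E}[(1 - 2\eta w\zeta)^2]$ for a weight $w$ and a Bernoulli variable $\zeta$; the function names and parameter values are illustrative.

```python
import random


def second_moment_factor(eta, a, f, method):
    """Per-step multiplier of E[(theta_t - theta*)^2] near a minimizer (Lemmas 1 and 2).

    a: population proportion of the group whose loss is active at that minimizer;
    f: its sampling proportion (used only by reweighting).
    """
    if method == "resampling":
        return (1 - 2 * eta * a) ** 2 + 4 * eta ** 2 * a * (1 - a)
    if method == "reweighting":
        return (1 - 2 * eta * a) ** 2 + 4 * eta ** 2 * a ** 2 * (1 - f) / f
    raise ValueError(f"unknown method: {method}")


def monte_carlo_factor(eta, a, f, method, n=100_000, seed=0):
    """Estimate E[(1 - 2*eta*w*zeta)^2], zeta ~ Bernoulli(p), by simulation."""
    p, w = (a, 1.0) if method == "resampling" else (f, a / f)
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        zeta = 1.0 if rng.random() < p else 0.0
        total += (1 - 2 * eta * w * zeta) ** 2
    return total / n


def max_stable_lr(a, f, method, hi=10.0, iters=100):
    """Largest eta with factor <= 1, by bisection (factor - 1 < 0 below eta_max, > 0 above)."""
    lo = 0.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if second_moment_factor(mid, a, f, method) <= 1.0:
            lo = mid
        else:
            hi = mid
    return lo


if __name__ == "__main__":
    print(second_moment_factor(0.1, 0.6, 0.05, "reweighting"),
          monte_carlo_factor(0.1, 0.6, 0.05, "reweighting"))
    print(max_stable_lr(0.6, 0.05, "reweighting"), max_stable_lr(0.6, 0.05, "resampling"))
```

For $a_1 = 0.6$ and $f_1 = 0.05$ the bisection returns $\eta \approx f_1/a_1 \approx 0.083$ for reweighting at $\theta = -1$, against $\eta = 1$ for resampling, which is the stiffness described above.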
4 SDE analysis
The stability analysis can only be carried out for a learning rate of finite size. However, even for a small learning rate $\eta$, one can show that the reweighting method is still unreliable from a different perspective. This section applies stochastic differential equation analysis to demonstrate this.
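Let us again use a simple example to illustrate the main idea. Consider the following two loss functions,
$$V_1(\theta) = \begin{cases} |\theta+1| - 1, & \theta \le 0,\\ \epsilon\theta, & \theta > 0,\end{cases} \qquad V_2(\theta) = \begin{cases} -\epsilon\theta, & \theta \le 0,\\ |\theta-1| - 1, & \theta > 0,\end{cases}$$
with $0 < \epsilon \ll 1$. The population loss function is $V(\theta) = a_1 V_1(\theta) + a_2 V_2(\theta)$ with local minimizers $\theta = -1$ and $\theta = 1$. Note that the $\epsilon$ terms are necessary. Without them, if SGD starts in $(-\infty, 0)$, all iterates stay in this region because there is no drift from $V_2$. Similarly, if SGD starts in $(0, \infty)$, no iterate ever moves to $(-\infty, 0)$. That means the result of SGD depends only on the initialization when the $\epsilon$ terms are absent.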
In Figure 2, we present numerical simulations of the resampling and reweighting methods for the designed loss function $V$. If $a_1 > a_2$, then the global minimizer of $V$ is $\theta = -1$ (see Figure 2(1)). Consider a setup in which the sampling proportions $(f_1, f_2)$ are quite different from the population proportions $(a_1, a_2)$. Figures 2(2) and 2(3) show the dynamics under the reweighting and resampling methods, respectively. The plots show that, while the trajectory for resampling is stable throughout, the trajectory for reweighting quickly escapes to the (non-global) local minimizer $\theta = 1$ even when it starts near the global minimizer $\theta = -1$.
Figure 2: (1) the loss function; (2) dynamics under reweighting; (3) dynamics under resampling.
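A minimal sketch of the kind of simulation behind Figure 2, assuming the piecewise linear losses $V_1(\theta) = |\theta+1| - 1$ for $\theta \le 0$ and $\epsilon\theta$ for $\theta > 0$, and $V_2(\theta) = -\epsilon\theta$ for $\theta \le 0$ and $|\theta-1| - 1$ for $\theta > 0$. The values $a_1 = 0.6$, $f_1 = 0.02$, $\epsilon = 0.1$, $\eta = 0.05$ are hypothetical; the function reports the final iterate and the fraction of steps spent in $\theta > 0$, the basin of the non-global minimizer.

```python
import random


def grads(theta, eps):
    """Gradients (g1, g2) of the assumed piecewise linear losses V_1 and V_2."""
    if theta <= 0:
        g1 = 1.0 if theta > -1 else -1.0   # V_1 = |theta + 1| - 1
        g2 = -eps                          # V_2 = -eps * theta
    else:
        g1 = eps                           # V_1 = eps * theta
        g2 = -1.0 if theta < 1 else 1.0    # V_2 = |theta - 1| - 1
    return g1, g2


def run(method, a1=0.6, f1=0.02, eps=0.1, eta=0.05, steps=100_000, theta0=-1.0, seed=0):
    """Single-sample SGD; returns (final theta, fraction of steps with theta > 0)."""
    rng = random.Random(seed)
    a2, f2 = 1 - a1, 1 - f1
    theta, time_right = theta0, 0
    for _ in range(steps):
        g1, g2 = grads(theta, eps)
        u = rng.random()
        if method == "resampling":
            g = g1 if u < a1 else g2
        else:  # reweighting: group 1 is drawn with its biased sampling proportion f1
            g = a1 / f1 * g1 if u < f1 else a2 / f2 * g2
        theta -= eta * g
        time_right += theta > 0
    return theta, time_right / steps


if __name__ == "__main__":
    print("resampling :", run("resampling"))
    print("reweighting:", run("reweighting"))
```

Under these assumptions the resampling run never leaves the basin of $\theta = -1$, while the large weight $a_1/f_1 = 30$ makes the rare group-1 steps jump by $1.5$, so the reweighting run crosses into $\theta > 0$ early and then stays near $\theta = 1$.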
When the learning rate $\eta$ is sufficiently small, SGD can be approximated by an SDE, which for this piecewise linear loss is approximately Langevin dynamics with a piecewise constant mobility. In particular, when the dynamics reaches equilibrium, the stationary distribution of the stochastic process is approximated by a Gibbs distribution, which gives the probability densities at the stationary points. Let us denote by $p^s$ and $p^w$ the stationary distributions over $\theta$ under resampling and reweighting, respectively. Suppose that $a_1 > a_2$ and hence $V(-1) < V(1)$. The following lemmas quantitatively summarize the results.
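Lemma 3. When $a_1 > a_2$, $V(-1) < V(1)$. To leading order in $\epsilon$, the stationary distribution for resampling satisfies the relationship
$$\frac{p^s(-1)}{p^s(1)} = \exp\Big(\frac{2\,\big(V(1) - V(-1)\big)}{\eta\, a_1 a_2}\Big) > 1.$$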
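Lemma 4. With $\epsilon \ll 1$, $V(-1) \approx -a_1$ and $V(1) \approx -a_2$. Under the condition $f_1/f_2 < \sqrt{a_1/a_2}$ for the sampling proportions, the stationary distribution for reweighting satisfies the relationship
$$\frac{p^w(-1)}{p^w(1)} \approx \frac{a_2^2 f_1^2}{a_1^2 f_2^2}\,\exp\Big(\frac{2}{\eta}\Big(\frac{f_1}{a_1 f_2} - \frac{f_2}{a_2 f_1}\Big)\Big) < 1.$$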
The proofs of the above two lemmas can be found in Appendix B. Lemma 3 shows that for resampling it is always more likely to find $\theta$ at the global minimizer $-1$ than at the local minimizer $1$. Lemma 4 states that for reweighting it is more likely to find $\theta$ at the local minimizer $1$ when $f_1/f_2 < \sqrt{a_1/a_2}$. Together, they explain the phenomenon shown in Figure 2.
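The two stationary ratios can be compared directly. The sketch below works with log-ratios to avoid overflow and assumes leading-order Gibbs densities with the $O(\epsilon)$ terms dropped: variance $a_1 a_2$ everywhere for resampling, and variances $a_1^2 f_2/f_1$ for $\theta < 0$ and $a_2^2 f_1/f_2$ for $\theta > 0$ for reweighting, glued together by continuity of $\sigma^2 p$ at $\theta = 0$. The function names are hypothetical.

```python
import math


def log_ratio_resampling(a1, eta):
    """log p^s(-1)/p^s(1) = 2 (V(1) - V(-1)) / (eta a1 a2), with V(1) - V(-1) ~ a1 - a2."""
    a2 = 1 - a1
    return 2 * (a1 - a2) / (eta * a1 * a2)


def log_ratio_reweighting(a1, f1, eta):
    """Leading-order log p^w(-1)/p^w(1) for reweighting with piecewise constant variance."""
    a2, f2 = 1 - a1, 1 - f1
    s1 = a1 ** 2 * f2 / f1   # gradient variance for theta < 0
    s2 = a2 ** 2 * f1 / f2   # gradient variance for theta > 0
    # Continuity of sigma^2 p at theta = 0 (where V = 0) gives Z2 / Z1 = s2 / s1;
    # V(-1) ~ -a1 and V(1) ~ -a2 enter the Gibbs exponents.
    return math.log(s2 / s1) + 2 * a1 / (eta * s1) - 2 * a2 / (eta * s2)


if __name__ == "__main__":
    print(log_ratio_resampling(0.6, 0.1), log_ratio_reweighting(0.6, 0.1, 0.1))
```

With unbiased sampling ($f_1 = a_1$) the two expressions coincide, as they should; with $a_1 = 0.6$, $f_1 = 0.1$, $\eta = 0.1$ the resampling log-ratio is positive (about 16.7) while the reweighting log-ratio is strongly negative (about $-450$).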
To better understand the condition in Lemma 4, let us consider the case $a_1/a_2 = 1 + \delta$ with a small constant $\delta > 0$. Under this setup, $\sqrt{a_1/a_2} \approx 1 + \delta/2$, which is close to $a_1/a_2$. Hence, whenever the ratio of the sampling proportions $f_1/f_2$ is significantly less than the ratio of the population proportions $a_1/a_2$, reweighting leads to the undesired behavior. The smaller the ratio $f_1/f_2$ is, the less likely the global minimizer $\theta = -1$ is to be visited.
Let us now consider the minimization of $V$ for more general $V_1$ and $V_2$, possibly in high dimensions. It is in fact not clear how to extend the above stochastic analysis to more general functions $V$. Instead, we focus on the transition time from one stationary point to another in order to understand the behavior of resampling and reweighting. For this purpose, we again resort to the SDE approximation of SGD in the continuous-time limit.
Such an SDE approximation, first introduced in (Li et al., 2017), involves a data-dependent covariance coefficient for the diffusion term and is justified in the weak sense with an error of order $O(\eta)$. More specifically, the dynamics can be approximated by
$$d\Theta_t = -\nabla V(\Theta_t)\,dt + \sqrt{\eta\,\Sigma(\Theta_t)}\,dW_t, \tag{6}$$
where $\Theta_{t\eta} \approx \theta_t$ for the step parameter $t$, $\eta$ is the learning rate, and $\Sigma(\theta)$ is the covariance of the stochastic gradient at location $\theta$. In SDE theory, the drift term is usually assumed to be Lipschitz. However, in machine learning (for example, neural network training with non-smooth activation functions), it is common to encounter non-Lipschitz gradients of loss functions (as in the example presented in Section 3). To fill this gap, we provide in Appendix C a justification of the SDE approximation for drifts with jump discontinuities, based on the proof presented in (Müller-Gronbach and Yaroslavtseva, 2020). The following two lemmas summarize the transition times between the two local minimizers.
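The identification $\Theta_{t\eta} \approx \theta_t$ can be seen at the level of a single step: over a time interval $dt = \eta$, an Euler-Maruyama step of (6) has the same mean and variance as one SGD increment. The sketch below checks this for a discrete stochastic gradient given as (probability, value) pairs; the example numbers in the check correspond to the reweighting setup of Section 4 with hypothetical $a_1 = 0.6$, $f_1 = 0.1$, $\epsilon = 0.1$ at a point $\theta \in (-1, 0)$.

```python
import math
import random


def gradient_moments(outcomes):
    """Mean and variance of a discrete stochastic gradient given as [(prob, value), ...]."""
    mean = sum(p * g for p, g in outcomes)
    var = sum(p * (g - mean) ** 2 for p, g in outcomes)
    return mean, var


def sgd_increment_moments(outcomes, eta):
    """One SGD step theta -> theta - eta * g: mean -eta E[g], variance eta^2 Var[g]."""
    mean, var = gradient_moments(outcomes)
    return -eta * mean, eta ** 2 * var


def sde_increment_moments(outcomes, eta, dt):
    """Euler-Maruyama increment of dTheta = -grad V dt + sqrt(eta Sigma) dW over time dt."""
    mean, var = gradient_moments(outcomes)
    return -mean * dt, eta * var * dt


def euler_maruyama_step(theta, grad_v, sigma2, eta, dt, rng):
    """Draw one Euler-Maruyama step of (6) for a scalar theta."""
    noise = math.sqrt(eta * sigma2(theta) * dt) * rng.gauss(0.0, 1.0)
    return theta - grad_v(theta) * dt + noise
```

The design choice is to compare exact moments rather than sample paths, so the check isolates the scaling $dt = \eta$ and the diffusion coefficient $\sqrt{\eta\Sigma}$ without Monte Carlo error.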
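Lemma 5. Assume that there are only two local minimizers $\theta_1$ and $\theta_2$ for the objective function $V$. Let $\tau_{1\to2}$ be the transition time for $\Theta_t$ in (6) from the $\delta$-neighborhood of $\theta_1$ (a closed ball of radius $\delta$ centered at $\theta_1$) to the $\delta$-neighborhood of $\theta_2$, and let $\tau_{2\to1}$ be the transition time in the opposite direction. Then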
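Here $\det H_1$ and $\det H_2$ are the determinants of the Hessians at $\theta_1$ and $\theta_2$, respectively, and $\Delta V_i = V(\theta_s) - V(\theta_i)$ for $i = 1, 2$, where $\theta_s$ is the saddle point between $\theta_1$ and $\theta_2$.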
This lemma is known in the diffusion process literature as the Eyring-Kramers formula; see, e.g., (Berglund, 2011; Bovier et al., 2004, 2005). Using the above lemma, we obtain the following result for the transition times for resampling and reweighting.
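For intuition about how the noise level enters the transition times, here is the classical one-dimensional Eyring-Kramers estimate for $dX = -V'(X)\,dt + \sqrt{2\nu}\,dW$ with constant noise $\nu$. This is a simplification of the setting above: in (6) the noise is state dependent, and matching $\sqrt{\eta\Sigma} = \sqrt{2\nu}$ locally gives $\nu = \eta\Sigma/2$. The double-well numbers used in the check are hypothetical.

```python
import math


def kramers_time(v2_min, v2_saddle, barrier, nu):
    """1D Eyring-Kramers estimate of the mean exit time from a well:
    E[tau] ~ 2 pi / sqrt(V''(min) |V''(saddle)|) * exp(barrier / nu)."""
    return 2 * math.pi / math.sqrt(v2_min * abs(v2_saddle)) * math.exp(barrier / nu)


def transition_time_ratio(v2_1, v2_2, v2_s, dv1, dv2, nu):
    """E[tau_{1->2}] / E[tau_{2->1}] for two wells sharing one saddle and the same noise nu."""
    return kramers_time(v2_1, v2_s, dv1, nu) / kramers_time(v2_2, v2_s, dv2, nu)


if __name__ == "__main__":
    # Symmetric double well V = (x^2 - 1)^2 / 4: V'' = 2 at x = +-1, V'' = -1 at x = 0.
    print(transition_time_ratio(2.0, 2.0, -1.0, 0.25, 0.25, 0.05))
```

With equal curvatures at the two minimizers the ratio reduces to $\exp((\Delta V_1 - \Delta V_2)/\nu) = \exp(2(\Delta V_1 - \Delta V_2)/(\eta\Sigma))$, so a noise level $\Sigma$ that depends on the sampling proportions, as it does for reweighting, enters the exponent directly.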
Assume that there are only two local minimizers for the objective function . Also assume that the loss function for the first group is in the -neighborhood of and the loss function for the second group is in the -neighborhood of . In addition, assume that the determinants of the Hessian at two local minimizers are the same. Then the ratio of the transition times between the two local minimizers for resampling is
and the ratio for reweighting is
See Appendix B for the proof. When the ratio is larger than , it means that is more stable than . This result shows that for reweighting the relative stability of the two minimizers highly depends on the sampling proportions . On the other hand, for resampling it is independent of . To see how the sampling proportions affect the behavior of reweighting, let us consider a simple case where is the global minimizer, , , and . This ensures that and the above ratio for resampling is larger than , which is the desired result. However, implies that , , and the above ratio for reweighting is much smaller than , which means that the local minimizer is more stable than the global minimizer .
This section examines the empirical performance of resampling and reweighting for problems from classification, regression, and reinforcement learning.
This experiment uses the Porto Seguro’s safe driver prediction data set
The code from the imbalanced-learn
To estimate the performance, rather than using the classification accuracy, which can be misleading for biased data, we use the area under the receiver operating characteristic curve (ROC-AUC) computed from the prediction scores. The ROC curve plots the true positive rate on the $y$-axis versus the false positive rate on the $x$-axis, so a larger area under the curve indicates better classifier performance. As shown in Figure 3, resampling outperforms reweighting in terms of the ROC-AUC scores for different sampling proportions $(f_1, f_2)$.
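The ROC-AUC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. A dependency-free sketch of this pairwise definition follows; in practice one would typically call scikit-learn's `roc_auc_score`, and the quadratic loop here is only for clarity.

```python
def roc_auc(labels, scores):
    """ROC-AUC as P(score of random positive > score of random negative), ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

In the small example, the positive scored 0.35 beats only one of the two negatives and the positive scored 0.8 beats both, giving 3 of 4 pairs, i.e., an AUC of 0.75.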
Using the Labeled Faces in the Wild (LFW) dataset, we compare the performance of resampling and reweighting for the facial recognition problem. Two celebrities are picked from the LFW dataset: the majority group among selected samples consists of photos of President George W. Bush, while the minority group contains photos of President Bill Clinton. The sampling proportions are and the underlying population proportions are .
We again adapt the code from the imbalanced-learn package and experiment with several oversampling methods (ADASYN, ROS, SMOTE) and undersampling methods (RUS, NearMiss), as well as reweighting, by training an SGDClassifier with a logistic regression loss. Figure 4 plots the ROC curves for all these methods. We emphasize that we do not claim that SGDClassifier gives the best prediction score. Instead, our motivation is to illustrate that the performance of SGD deteriorates when reweighting is used.
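For readers without imbalanced-learn at hand, the two simplest resampling schemes in this comparison (random oversampling, ROS, and random undersampling, RUS) and the matching reweighting can be sketched in plain Python. These helpers are hypothetical stand-ins, not the imbalanced-learn API; they balance the classes to equal counts, and the weights $n/(k\,n_c)$ are the formula behind scikit-learn's `class_weight="balanced"`, which equals $a_c/f_c$ with target proportions $a_c = 1/k$.

```python
import random


def _indices_by_class(labels):
    """Map each class label to the list of sample indices carrying it."""
    classes = sorted(set(labels))
    return {c: [i for i, y in enumerate(labels) if y == c] for c in classes}


def random_oversample(labels, rng):
    """ROS sketch: duplicate minority indices (with replacement) up to the majority count."""
    idx = _indices_by_class(labels)
    target = max(len(v) for v in idx.values())
    out = []
    for members in idx.values():
        out += members + rng.choices(members, k=target - len(members))
    return out


def random_undersample(labels, rng):
    """RUS sketch: keep a random subset of each class of the minority-class size."""
    idx = _indices_by_class(labels)
    target = min(len(v) for v in idx.values())
    out = []
    for members in idx.values():
        out += rng.sample(members, target)
    return out


def balanced_weights(labels):
    """Reweighting counterpart: weight n / (k * n_c) for every sample of class c."""
    n, idx = len(labels), _indices_by_class(labels)
    return {c: n / (len(idx) * len(members)) for c, members in idx.items()}
```

The returned index lists can be used to subset any feature matrix before training, while the weights would be passed as per-sample weights, which is exactly the resampling-versus-reweighting choice studied here.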
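In the off-policy prediction problem in reinforcement learning, the objective is to find the value function $V^\pi$ of a target policy $\pi$ using trajectories generated by a behavior policy $\mu$. To achieve this, the standard approach is to update the value function based on the behavior policy's temporal difference (TD) error $\delta(s,a) = r(s,a) + \gamma V(s') - V(s)$ with an importance weight, using
$$\mathbb{E}_{a\sim\pi}\big[\delta(s,a)\big] = \sum_{a\in\mathcal{A}} \mu(a|s)\,\frac{\pi(a|s)}{\mu(a|s)}\,\delta(s,a),$$
where the summation is taken over the action space $\mathcal{A}$. The resulting reweighting TD learning for policy $\pi$ is
$$V(s_t) \leftarrow V(s_t) + \eta\,\frac{\pi(a_t|s_t)}{\mu(a_t|s_t)}\big(r_t + \gamma V(s_{t+1}) - V(s_t)\big),$$
where $\eta$ is the learning rate. This update rule is an example of reweighting. On the other hand, the expected TD error can also be written in the resampling form
$$\sum_{a\in\mathcal{A}} \pi(a|s)\,\delta(s,a) \approx \sum_{k=1}^{N_s} \frac{\rho_k}{\sum_{j=1}^{N_s}\rho_j}\,\delta_k, \qquad \rho_k = \frac{\pi(a_k|s)}{\mu(a_k|s)},$$
where $\delta_k$ is the TD error of the $k$-th stored transition from $s$ and $N_s$ is the total number of samples for $s$. The right-hand side is the expectation of $\delta_k$ when transition $k$ is drawn with probability $\rho_k/\sum_j \rho_j$. This results in a resampling TD learning algorithm: at step $t$,
$$V(s_k) \leftarrow V(s_k) + \eta\,\big(r_k + \gamma V(s_k') - V(s_k)\big),$$
where the transition $k$ is randomly chosen from the data set with probability proportional to $\rho_k$.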
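Both update rules can be sketched for a tabular value function stored as a list. Everything here is an illustrative assumption: transitions are stored as `(s, a, r, s_next)` tuples, policies are callables `pi(a, s)` returning probabilities, and the resampling step redraws from the whole stored data set with probability proportional to $\rho = \pi(a|s)/\mu(a|s)$ before applying an unweighted TD(0) update.

```python
import random


def td_error(V, s, r, s_next, gamma):
    """delta = r + gamma * V(s') - V(s)."""
    return r + gamma * V[s_next] - V[s]


def reweighting_td_step(V, transition, pi, mu, eta, gamma):
    """Importance-weighted TD(0): the TD error is scaled by rho = pi(a|s) / mu(a|s)."""
    s, a, r, s_next = transition
    rho = pi(a, s) / mu(a, s)
    V[s] += eta * rho * td_error(V, s, r, s_next, gamma)


def resampling_td_step(V, data, pi, mu, eta, gamma, rng):
    """Draw a stored transition with probability proportional to rho, then apply plain TD(0).

    Recomputing all weights on each call costs O(len(data)); a real implementation would cache them.
    """
    rhos = [pi(a, s) / mu(a, s) for s, a, _, _ in data]
    s, a, r, s_next = rng.choices(data, weights=rhos, k=1)[0]
    V[s] += eta * td_error(V, s, r, s_next, gamma)
    return s, a, r, s_next
```

The only difference between the two steps is where the ratio $\rho$ enters: as a multiplier of the step size in reweighting, and as a sampling probability in resampling, which keeps the effective step size bounded by $\eta$.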
Consider a simple example with discrete state space $\mathcal{S} = \{0, 1, \dots, n-1\}$, action space $\mathcal{A} = \{-1, +1\}$, and transition dynamics $s_{t+1} = (s_t + a_t) \bmod n$, where the operator $\bmod\, n$ gives the remainder of $s_t + a_t$ divided by $n$. Figure 5 shows the results of off-policy TD learning with these two approaches. The target policy $\pi$ and the behavior policy $\mu$ are chosen so that the difference between the two policies becomes larger as a constant in their definition increases. From the previous analysis, if one group has far fewer samples than it should, then the minimizer of the reweighting method is strongly affected by the sampling bias. This is verified in the plots: as the constant becomes larger, the performance of reweighting deteriorates, while resampling remains stable and shows almost no difference from on-policy prediction in this example.
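The ring dynamics can be simulated directly. The sketch below assumes actions $a \in \{-1, +1\}$, represents a policy as a hypothetical function returning $P(a = +1 \mid s)$, and takes a user-supplied reward function, since the specific policies and rewards of the experiment are not reproduced here.

```python
import random


def ring_step(s, a, n):
    """Ring transition s' = (s + a) mod n; Python's % is non-negative for n > 0."""
    return (s + a) % n


def rollout(p_plus, reward, n, s0, steps, rng):
    """Generate (s, a, r, s') tuples, choosing a = +1 with probability p_plus(s), else a = -1."""
    data, s = [], s0
    for _ in range(steps):
        a = 1 if rng.random() < p_plus(s) else -1
        s_next = ring_step(s, a, n)
        data.append((s, a, reward(s, a), s_next))
        s = s_next
    return data


if __name__ == "__main__":
    rng = random.Random(0)
    print(rollout(lambda s: 0.5, lambda s, a: float(a), n=5, s0=0, steps=5, rng=rng))
```

The resulting tuples are in the `(s, a, r, s_next)` form used by importance-weighted TD(0), and the wrap-around at $s = n-1$ and $s = 0$ is handled entirely by the modulo.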
This paper examines the different behaviors of reweighting and resampling when training with stochastic gradient descent on data with sampling bias. From both the dynamical stability and the stochastic asymptotics viewpoints, we explain why resampling is numerically more stable and robust than reweighting. Based on this theoretical understanding, we advocate for considering data, model, and optimization as an integrated system when addressing the bias.
An immediate direction for future work is to apply the analysis to more sophisticated stochastic training algorithms and understand their impact on resampling and reweighting. Another direction is to extend our analysis to unsupervised learning problems.
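Appendix A Proofs in Section 3
A.1 Proof of Lemma 1
In resampling, near $\theta = -1$ the gradient is $2(\theta+1)$ with probability $a_1$ and $0$ with probability $a_2$. Let us denote the random gradient at each step by $2\zeta(\theta+1)$, where $\zeta$ is a Bernoulli random variable with mean $a_1$ and variance $a_1 a_2$. At the learning rate $\eta$, the iteration can be written as
$$\theta_{t+1} + 1 = (1 - 2\eta\zeta_t)(\theta_t + 1).$$
The first and second moments of the iterates are
$$\mathbb{E}[\theta_{t+1} + 1] = (1 - 2\eta a_1)\,\mathbb{E}[\theta_t + 1], \qquad \mathbb{E}\big[(\theta_{t+1} + 1)^2\big] = \big((1 - 2\eta a_1)^2 + 4\eta^2 a_1 a_2\big)\,\mathbb{E}\big[(\theta_t + 1)^2\big].$$
According to the definition of stochastic stability, SGD is stable around $\theta = -1$ if the multiplicative factor in the second equation is bounded by 1, i.e.,
$$(1 - 2\eta a_1)^2 + 4\eta^2 a_1 a_2 \le 1.$$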
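Consider now the stability around $\theta = 1$, where the iteration can be written as
$$\theta_{t+1} - 1 = (1 - 2\eta\zeta_t)(\theta_t - 1),$$
where $\zeta_t$ is again a Bernoulli random variable, now with $\mathbb{E}[\zeta_t] = a_2$ and $\mathrm{Var}[\zeta_t] = a_1 a_2$. The same computation shows that the second moment follows
$$\mathbb{E}\big[(\theta_{t+1} - 1)^2\big] = \big((1 - 2\eta a_2)^2 + 4\eta^2 a_1 a_2\big)\,\mathbb{E}\big[(\theta_t - 1)^2\big].$$
Therefore, the condition for SGD to be stable around $\theta = 1$ is
$$(1 - 2\eta a_2)^2 + 4\eta^2 a_1 a_2 \le 1.$$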
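A.2 Proof of Lemma 2
In reweighting, near $\theta = -1$ the gradient is $\frac{2a_1}{f_1}(\theta+1)$ with probability $f_1$ and $0$ with probability $f_2$. Let us denote the random gradient at each step by $\frac{2a_1}{f_1}\zeta(\theta+1)$, where $\zeta$ is a Bernoulli random variable with $\mathbb{E}[\zeta] = f_1$ and $\mathrm{Var}[\zeta] = f_1 f_2$. At the learning rate $\eta$, the iteration can be written as
$$\theta_{t+1} + 1 = \Big(1 - \frac{2\eta a_1}{f_1}\zeta_t\Big)(\theta_t + 1).$$
Hence the second moments of the iterates are given by
$$\mathbb{E}\big[(\theta_{t+1}+1)^2\big] = \Big(1 - 4\eta a_1 + \frac{4\eta^2 a_1^2}{f_1}\Big)\mathbb{E}\big[(\theta_t+1)^2\big] = \Big((1 - 2\eta a_1)^2 + 4\eta^2 a_1^2\,\frac{f_2}{f_1}\Big)\mathbb{E}\big[(\theta_t+1)^2\big].$$
Therefore, the condition for SGD to be stable around $\theta = -1$ is
$$(1 - 2\eta a_1)^2 + 4\eta^2 a_1^2\,\frac{f_2}{f_1} \le 1.$$
Consider now the stability around $\theta = 1$: the gradient is $\frac{2a_2}{f_2}(\theta - 1)$ with probability $f_2$ and $0$ with probability $f_1$. An analysis similar to the case $\theta = -1$ shows that the condition for SGD to be stable around $\theta = 1$ is
$$(1 - 2\eta a_2)^2 + 4\eta^2 a_2^2\,\frac{f_1}{f_2} \le 1.$$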
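Appendix B Proofs in Section 4
B.1 Proof of Lemma 3
In resampling, with probability $a_1$ the gradients over the four intervals $(-\infty,-1)$, $(-1,0)$, $(0,1)$, and $(1,\infty)$ are $-1$, $1$, $\epsilon$, and $\epsilon$. With probability $a_2$, they are $-\epsilon$, $-\epsilon$, $-1$, and $1$ across these four intervals. The variances of the gradients are $a_1a_2(1-\epsilon)^2$, $a_1a_2(1+\epsilon)^2$, $a_1a_2(1+\epsilon)^2$, and $a_1a_2(1-\epsilon)^2$, respectively, across the same intervals.
Since $\epsilon \ll 1$, the variance can be written as $a_1 a_2 + O(\epsilon)$ across all intervals. Then the SGD dynamics with learning rate $\eta$ can be approximated by
$$\theta_{t+1} = \theta_t - \eta\nabla V(\theta_t) + \eta\sqrt{a_1 a_2}\,\xi_t,$$
where $\xi_t$ is a standard normal random variable. When $\eta$ is small, one can approximate the dynamics by a stochastic differential equation of the form
$$d\Theta_t = -\nabla V(\Theta_t)\,dt + \sqrt{\eta\,a_1 a_2}\,dW_t$$
by identifying $\Theta_{t\eta} \approx \theta_t$ (see Appendix C for details). The stationary distribution of this stochastic process is
$$p^s(\theta) = \frac{1}{Z}\exp\Big(-\frac{2V(\theta)}{\eta\,a_1 a_2}\Big),$$
where $Z$ is the normalization constant. Plugging in $\theta = \pm 1$, with $V(-1) = -a_1 + \epsilon a_2$ and $V(1) = -a_2 + \epsilon a_1$, results in
$$\frac{p^s(-1)}{p^s(1)} = \exp\Big(\frac{2\,\big(V(1) - V(-1)\big)}{\eta\,a_1 a_2}\Big) = \exp\Big(\frac{2(a_1 - a_2)}{\eta\,a_1 a_2} + \frac{2\epsilon(a_1 - a_2)}{\eta\,a_1 a_2}\Big).$$
Under the assumption that $\epsilon \ll 1$, the last term is negligible. When $a_1 > a_2$, $V$ is minimized at $\theta = -1$, which implies $V(-1) < V(1)$. Hence this ratio is larger than 1. ∎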
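B.2 Proof of Lemma 4
In reweighting, with probability $f_1$ the gradients are $-\frac{a_1}{f_1}$, $\frac{a_1}{f_1}$, $\frac{\epsilon a_1}{f_1}$, and $\frac{\epsilon a_1}{f_1}$ over the four intervals $(-\infty,-1)$, $(-1,0)$, $(0,1)$, and $(1,\infty)$, respectively. With probability $f_2$, they are $-\frac{\epsilon a_2}{f_2}$, $-\frac{\epsilon a_2}{f_2}$, $-\frac{a_2}{f_2}$, and $\frac{a_2}{f_2}$. The variances of the gradients are $f_1 f_2\big(\frac{a_1}{f_1} - \frac{\epsilon a_2}{f_2}\big)^2$, $f_1 f_2\big(\frac{a_1}{f_1} + \frac{\epsilon a_2}{f_2}\big)^2$, $f_1 f_2\big(\frac{a_2}{f_2} + \frac{\epsilon a_1}{f_1}\big)^2$, and $f_1 f_2\big(\frac{a_2}{f_2} - \frac{\epsilon a_1}{f_1}\big)^2$, respectively, across the same intervals.
Since $\epsilon \ll 1$, the variance can be written, up to $O(\epsilon)$, as $\sigma_1^2 = a_1^2 f_2/f_1$ for $\theta < 0$ and $\sigma_2^2 = a_2^2 f_1/f_2$ for $\theta > 0$.
With $\Theta_{t\eta} \approx \theta_t$, the approximate SDE for $\theta < 0$ is given by
$$d\Theta_t = -\nabla V(\Theta_t)\,dt + \sqrt{\eta}\,\sigma_1\,dW_t,$$
while the one for $\theta > 0$ is
$$d\Theta_t = -\nabla V(\Theta_t)\,dt + \sqrt{\eta}\,\sigma_2\,dW_t$$
(see Appendix C for the SDE derivations). The stationary distributions for $\theta < 0$ and $\theta > 0$ are, respectively,
$$p^w(\theta) = \frac{1}{Z_1}\exp\Big(-\frac{2V(\theta)}{\eta\,\sigma_1^2}\Big), \qquad p^w(\theta) = \frac{1}{Z_2}\exp\Big(-\frac{2V(\theta)}{\eta\,\sigma_2^2}\Big).$$
Plugging in $\theta = \pm 1$, with $V(-1) \approx -a_1$ and $V(1) \approx -a_2$ to leading order in $\epsilon$, results in
$$\frac{p^w(-1)}{p^w(1)} = \frac{Z_2}{Z_1}\exp\Big(\frac{2a_1}{\eta\,\sigma_1^2} - \frac{2a_2}{\eta\,\sigma_2^2}\Big). \tag{10}$$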
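The next step is to determine the relationship between $Z_1$ and $Z_2$. Consider an SDE with non-smooth diffusion $d\Theta_t = -\nabla V(\Theta_t)\,dt + \sqrt{\eta}\,\sigma(\Theta_t)\,dW_t$. The Kolmogorov forward equation for the stationary distribution $p$ is
$$0 = \partial_\theta\big(\nabla V(\theta)\,p(\theta)\big) + \frac{\eta}{2}\,\partial_\theta^2\big(\sigma^2(\theta)\,p(\theta)\big).$$
Integrating once and using that the stationary probability flux vanishes gives $\nabla V\,p + \frac{\eta}{2}\,\partial_\theta(\sigma^2 p) = 0$; since $\nabla V\,p$ is bounded, this suggests that $\sigma^2 p$ is continuous at the discontinuity $\theta = 0$. In our setting, since $V(0) = 0$, this simplifies to
$$\frac{\sigma_1^2}{Z_1} = \frac{\sigma_2^2}{Z_2}.$$
This simplifies to
$$\frac{Z_2}{Z_1} = \frac{\sigma_2^2}{\sigma_1^2} = \frac{a_2^2 f_1^2}{a_1^2 f_2^2}.$$
Inserting this into (10) results in
$$\frac{p^w(-1)}{p^w(1)} = \frac{a_2^2 f_1^2}{a_1^2 f_2^2}\,\exp\Big(\frac{2}{\eta}\Big(\frac{f_1}{a_1 f_2} - \frac{f_2}{a_2 f_1}\Big)\Big).$$
By the assumptions $a_1 > a_2$ and $f_1/f_2 < \sqrt{a_1/a_2}$, one has $\frac{a_2 f_1}{a_1 f_2} < 1$ and $\frac{f_1}{a_1 f_2} < \frac{f_2}{a_2 f_1}$. Hence the above ratio is less than 1. ∎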
B.3 Proof of Lemma 6
The variances of gradients for resampling and reweighting are respectively,