Evaluating Model Robustness to Dataset Shift


Abstract

As the use of machine learning in safety-critical domains becomes widespread, the importance of evaluating the safety of these systems has increased. An important aspect of this is evaluating how robust a model is to changes in setting or population, which typically requires applying the model to multiple, independent datasets. Since the cost of collecting such datasets is often prohibitive, in this paper we propose a framework for evaluating this type of robustness using a single, fixed evaluation dataset. We use the original evaluation data to define an uncertainty set of possible evaluation distributions and estimate the algorithm’s performance on the “worst-case” distribution within this set. Specifically, we consider shifts in conditional distributions, allowing some portions of the data distribution to shift while keeping others fixed. This results in finer-grained control over the considered shifts and more plausible worst-case distributions than previous approaches based on covariate shifts. To address the challenges associated with estimation in complex, high-dimensional distributions, we derive a “debiased” estimator which maintains $\sqrt{n}$-consistency even when machine learning methods with slower convergence rates are used to estimate the nuisance parameters. In experiments on a real medical risk prediction task, we show that this estimator can be used to evaluate robustness and that it accounts for realistic shifts that cannot be expressed as covariate shift. The proposed framework provides a means for practitioners to proactively evaluate the safety of their models using a single validation dataset.

1 Introduction

The environments in which we deploy machine learning (ML) algorithms rarely look exactly like the environments in which we collected our training data. Unfortunately, we lack methodology for evaluating how well an algorithm will generalize to new environments that differ in a structured way from the training data (i.e., the case of dataset shift (Quiñonero-Candela et al., 2009)). Such methodology is increasingly important as ML systems are being deployed across a number of industries, such as health care and personal finance, in which system performance translates directly to real-world outcomes. Further, as regulation and product reviews become more common across industries, system developers will be expected to produce evidence of the validity and safety of their systems. For example, the United States Food and Drug Administration (FDA) currently regulates ML systems for medical applications, requiring evidence for the validity of such systems before approval is granted (US Food and Drug Administration, 2019).

Evaluation methods for assessing model validity have typically focused on how the model performs on data from the training distribution, known as internal validity. Powerful tools, such as cross-validation and the bootstrap, rely on the assumption that the training and test data are drawn from the same distribution. However, these validation methods do not capture a model’s ability to generalize to new environments, known as external validity (Campbell and Stanley, 1963). Currently, the main way to assess a model’s external validity is to empirically evaluate performance on multiple, independently collected datasets (e.g., as commonly done in healthcare (Subbaswamy and Saria, 2020)); however, this approach has practical limitations. If we cannot completely characterize how the datasets differ, or if the datasets are not sufficiently diverse, then this evaluation provides only weak evidence of external validity. To counteract this, one could consider targeted collection of additional datasets which differ in structured ways (e.g., collecting data with differing demographics). This approach can be prohibitively costly, or in some cases, impossible. For example, we cannot ethically collect new loan approval datasets in which we forcibly vary customer spending habits. Thus, we need a feasible alternative for assessing external validity without relying on external data.

One way to tackle this is through the lens of distributionally robust optimization (DRO) (Ben-Tal et al., 2013; Duchi et al., 2016). Instead of examining a model’s performance only on the (empirical) test distribution associated with a particular validation dataset, DRO defines an uncertainty set of possible test distributions and considers the model’s performance on the worst-case distribution chosen from this set. A model that performs well on this worst-case distribution is considered robust to dataset shift. This worst-case formulation is particularly sensible in safety-critical domains in which the cost of failure is high. For example, consider evaluating a model trained to diagnose a disease from a set of covariates (e.g., age, medical history, treatments). Here, we may wish to evaluate how the performance of a model would change if it were deployed in another hospital or if clinicians begin ordering different tests. These different changes in environment correspond to different types of distributional shifts. In order to build evidence that the model is safe to deploy in a variety of environments, it is necessary to evaluate the worst-case performance under a corresponding variety of shifts.

Figure 1: Age distribution of the worst-case populations resulting from no shift (Unshifted), shifts in the marginal distribution of features (Marginal), and shifts in the way tests are ordered (Conditional) for a fixed shift magnitude. The age distribution changes substantially under covariate shift, but remains close to the unshifted distribution under the more targeted evaluation.

For this reason, it is important that we have a framework for specifying how the data distribution can change (i.e., defining the uncertainty set) that is flexible enough to reflect targeted changes in environment. Despite this, much of the work on learning robust models considers shifts in the entire joint distribution (e.g., Ben-Tal et al. (2013); Duchi et al. (2016)) or shifts in the covariate distribution (e.g., Duchi et al. (2020); Chen et al. (2016); Wang et al. (2015); Liu and Ziebart (2014, 2017)). These approaches restrict the potential shifts by constraining the magnitude of the shift or particular moments of the shifted distribution; however, they cannot express finer-grained shifts which, for example, let us isolate changes in the characteristics of a population from decisions made based on those characteristics.

For example, suppose we were interested in evaluating the robustness of the diagnosis model’s performance to changes in the way clinicians order tests, while keeping the underlying patient population fixed. If the model’s predictions depend heavily on the results of a particular test and it is deployed to a hospital where that test is not common, the model could become unsafe to use. In order to evaluate the model under this type of change, we need a framework that allows us to specify shifts in the conditional distribution of test orders given patient characteristics. Previous approaches, which allow only for shifts in the covariate distribution, may result in worst-case distributions that do not reflect the type of shift we are interested in. As an example (Fig 1), we evaluated a real diagnostic model’s robustness to covariate shift, which resulted in a worst-case age distribution (green) that differed substantially from the observed age distribution (blue). By using the more flexible framework described in this work, we were able to evaluate the model’s performance under changes in the distribution of test orders while keeping the age distribution fixed (orange).

In this paper, we develop a method for evaluating the robustness of models to dataset shift without requiring the collection of new datasets. We make the following contributions: First, by defining uncertainty sets using shifts in conditional distributions, we generalize previous DRO formulations which consider shifts in marginal or joint distributions (Section 2). This provides substantially more control over the types of shifts we can consider and allows us to evaluate a model’s robustness in more targeted and realistic scenarios. Second, we propose the first $\sqrt{n}$-consistent method for estimating the worst-case expected loss under these types of distributional shifts (Section 3.1). Finally, on a medical risk prediction task, we demonstrate that this method can be used to evaluate the robustness of a model’s performance to distribution shifts. We further show that using a less general formulation of distributional shift can lead to incorrect conclusions about a model’s robustness and safety.

2 Methods

We are interested in evaluating a prediction model $f$ which has been trained to predict a target variable $Y$ from a set of covariates $X$. As a running example, consider evaluating a model for diagnosing a disease (i.e., $Y$ is a binary label for the presence of the disease) from a patient’s medical history. As in the case of a third-party reviewer, we will assume that $f$ is fixed and that we are evaluating the performance of $f$ on a fixed test dataset drawn i.i.d. from some distribution $P$. Classically, to evaluate $f$, we would select a loss function $\ell$ and estimate the expected loss under the test distribution as $\mathbb{E}_P[\ell(f(X), Y)]$. However, in addition to the expected loss under $P$, in many practical applications we would like to know how robust the model’s expected loss is to changes or differences in $P$ that we expect to see upon deployment, referred to as distribution shift. In this section, we describe how to estimate the expected loss of $f$ under shifts in the distribution $P$. We proceed by formally defining distribution shifts, specifying our objective for evaluating model performance under shifts, and, finally, deriving an estimator for this performance.

2.1 Defining Distribution Shifts

To define general distribution shifts, we will partition the variables into three sets. Let $Z$ be a set of variables whose marginal distribution should remain fixed, let $W$ be a set of variables whose distribution (given $Z$) we allow to vary, and let $V$ be the remaining variables. This partition of the variables defines a factorization of $P(V, W, Z)$ into $P(V \mid W, Z)\,P(W \mid Z)\,P(Z)$. Then, we consider how model performance changes when $P(W \mid Z)$ is replaced with a new distribution $Q(W \mid Z)$, while leaving $P(V \mid W, Z)$ and $P(Z)$ unchanged. Notably, this formulation generalizes other commonly studied instances of distribution shift. For example, if we let $W = X$ and $Z = \emptyset$, then this corresponds to a marginal (Duchi et al., 2020) or covariate shift (Shimodaira, 2000; Sugiyama et al., 2007).

Changing which variables are in $W$ and $Z$ allows us to change the type of shift we are interested in. Returning to our diagnosis example, setting $Z = \emptyset$ and $W$ to the set of patient demographic and history variables allows us to evaluate what would happen if the patient population were to change. On the other hand, we may also wish to know how our diagnostic algorithm will perform under changes in treatment policies employed by hospitals when the underlying patient population remains the same. For example, clinicians often order tests (e.g., blood tests) in order to inform diagnoses. By setting $W$ to the binary indicator for a particular test order and $Z$ to include patient information used by doctors to decide on a test order, we can evaluate how performance varies under changes in the way clinicians order this test.1
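To make the choice of partition concrete, the sketch below shows one way it might be written down for the sepsis example; the column names and grouping are hypothetical illustrations, not the paper’s actual feature set.

```python
# Hypothetical variable partition for the sepsis example (column names invented).
partition = {
    # W: variables whose conditional distribution P(W | Z) is allowed to shift;
    # here, the indicator that a particular lab test was ordered.
    "W": ["lactate_ordered"],
    # Z: variables whose marginal distribution P(Z) is held fixed, including the
    # patient information clinicians use when deciding whether to order the test.
    "Z": ["age", "sex", "heart_rate", "temperature", "prior_infection"],
    # V: everything else (e.g., the lab result and the label) keeps its
    # conditional distribution P(V | W, Z) and is left unchanged by the shift.
    "V": ["lactate_value", "sepsis_label"],
}
```

Evaluating robustness to a change in patient population instead would amount to moving the demographic and history columns into $W$ and leaving $Z$ empty.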

2.2 Quantifying Performance Under Shifts

To evaluate the robustness of $f$ to these types of distributional shifts, we estimate the worst-case expected loss under changes in $P(W \mid Z)$. We now define a tractable objective for estimating this loss.

Formally, we define an uncertainty set or “ball” of possible shifted distributions using a statistical divergence $D$ and radius $\delta$:

$\mathcal{U}_{\delta}(P) = \big\{\, Q(W \mid Z) \;:\; D\big(Q(W \mid Z)\,\|\,P(W \mid Z)\big) \le \delta,\;\; Q(\cdot \mid Z) \ll P(\cdot \mid Z) \text{ almost surely} \,\big\}.$   (1)

We are interested in the performance under the distribution $Q \in \mathcal{U}_{\delta}(P)$ that maximizes the expected loss

$R_{\max}(\delta) = \sup_{Q \in \mathcal{U}_{\delta}(P)} \; \mathbb{E}_{P(Z)\,Q(W \mid Z)}\big[\mu(W, Z)\big],$   (2)

where $\mu(W, Z) = \mathbb{E}_P[\ell(f(X), Y) \mid W, Z]$ is the conditional expected loss given $W$ and $Z$. By construction, $Q$ will never place positive weight on regions where $P$ does not. While we can change the rates at which events occur in our test data, we cannot evaluate performance in situations that we have never seen. For example, we cannot reliably estimate how a model will perform in an intensive care unit (ICU) using only data from an emergency department (ED), as there are events that can occur in an ICU that cannot occur in an ED.

Calculating $R_{\max}(\delta)$ requires calculating the expected loss under various distributions $Q$. As in previous work on DRO and domain adaptation, we rely on importance sampling, which allows us to estimate expectations under $Q$ using samples from $P$. This is done by reweighting samples from $P$ by the likelihood ratio $Q(W \mid Z)/P(W \mid Z)$. Specifically, the expected loss under $Q$ can be rewritten as

$\mathbb{E}_{P(Z)\,Q(W \mid Z)}\big[\mu(W, Z)\big] \;=\; \mathbb{E}_{P}\!\left[\frac{Q(W \mid Z)}{P(W \mid Z)}\,\mu(W, Z)\right].$   (3)

If $Q$ is quite different from $P$, the variance of importance sampling can be high. This variance is naturally governed by $\delta$, which controls how different $Q$ can be from $P$. In order to consider environments that look very different from $P$, a large test dataset may be needed.
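As a concrete illustration of Equation 3, the sketch below reweights per-example losses by the likelihood ratio of a hypothetical shifted test-ordering policy; the arrays and the doubling of the ordering probability are invented for illustration.

```python
import numpy as np

def shifted_risk(loss, likelihood_ratio):
    """Importance-sampling estimate of the expected loss under Q (Equation 3),
    using samples drawn under P and per-example ratios Q(W | Z) / P(W | Z)."""
    return np.mean(likelihood_ratio * loss)

# Hypothetical example: a shifted policy that doubles each patient's probability
# of receiving a test order W, capped at 1. All numbers below are invented.
loss = np.array([0.2, 1.3, 0.4, 0.9])      # per-example losses l(f(X_i), Y_i)
w = np.array([1, 0, 0, 1])                 # observed test-order indicators
p_w1 = np.array([0.3, 0.7, 0.6, 0.4])      # P(W = 1 | Z_i) under the test data
q_w1 = np.minimum(2 * p_w1, 1.0)           # Q(W = 1 | Z_i) under the shifted policy
ratio = np.where(w == 1, q_w1 / p_w1, (1 - q_w1) / (1 - p_w1))
print(shifted_risk(loss, ratio))
```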

The construction of the uncertainty set and computation of worst-case loss w.r.t. this set can be done for different choices of $D$ (Duchi and Namkoong, 2018). We will perform our analysis for one choice of $D$, the conditional value at risk (CVaR), for which the uncertainty set at radius $\delta$ consists of the conditional distributions $Q(W \mid Z)$ whose likelihood ratio $Q(W \mid Z)/P(W \mid Z)$ is bounded above by $1/\alpha$, where $\alpha = 1 - \delta$. This choice of $D$ makes computations involving the dual of Equation 2 tractable and allows for various statistical guarantees when estimating $R_{\max}(\delta)$. Further, as described next, the worst-case distribution under CVaR has a convenient interpretation.

Having chosen $D$, we now rewrite Equation 2 as

$R_{\max}(\delta) = \sup_{h \,:\, 0 \le h(W,Z) \le 1/\alpha} \; \mathbb{E}_{P}\big[h(W, Z)\,\mu(W, Z)\big]$   (4)
subject to $\;\mathbb{E}_{P}\big[h(W, Z) \mid Z\big] = 1\;$ almost surely,   (5)

where $\alpha = 1 - \delta$. Rather than optimizing directly over $Q$ as in Equation 2, we optimize over the likelihood ratio $h(W, Z)$, which is equal to $Q(W \mid Z)/P(W \mid Z)$. The interpretability of CVaR is now more apparent: the $h$ that maximizes this constrained objective is a (scaled) indicator function which selects a subpopulation of weight $\alpha$ from $P$. The radius $\delta$ has been mapped to a subpopulation size $\alpha$ such that larger shifts correspond to larger $\delta$ (i.e., smaller subpopulations). Under CVaR, determining the worst-case distribution shift is equivalent to the problem of finding a worst-case (i.e., high loss) subpopulation of a particular size from the original dataset. Thus, computing worst-case performance under distribution shift simply requires computing performance on this subpopulation.

To illustrate this process, consider a simple two-dimensional data generating mechanism with features $X_1$ and $X_2$ and a binary label. Samples from this distribution are shown in Fig 2a. Suppose that we want to evaluate the robustness of a pre-trained linear classifier (also shown in Fig 2a) under distribution shifts. Under a shift in the full feature distribution (i.e., $W = (X_1, X_2)$, $Z = \emptyset$), Equation 4 simply selects the $\alpha$-fraction of points which yield the highest conditional loss (i.e., incorrectly classified points furthest from the decision boundary), as shown in Fig 2b. By contrast, under a shift in the conditional distribution of one feature given the other (say, $P(X_2 \mid X_1)$; Fig 2c), points producing high conditional loss are chosen subject to the additional constraint of keeping the marginal of the conditioning feature the same as in Fig 2a. This is a direct consequence of constraint 5 and is the primary difference between the two shifts.
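The selection mechanics can be reproduced in a few lines of code. The sketch below uses an invented two-dimensional Gaussian mechanism (not the paper’s exact one) and approximates conditioning on $X_1$ by quantile binning: the marginal shift takes the worst $\alpha$-fraction of points overall, while the conditional shift takes the worst $\alpha$-fraction within each $X_1$ bin, which roughly preserves the $X_1$ marginal.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Hypothetical 2-D mechanism (the paper's exact specification is not reproduced).
n = 5000
X = rng.normal(size=(n, 2))
y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(size=n) > 0).astype(int)

clf = LogisticRegression().fit(X, y)
p1 = clf.predict_proba(X)[:, 1]
losses = -(y * np.log(p1) + (1 - y) * np.log(1 - p1))  # per-example log loss

alpha = 0.3              # subpopulation size; shift magnitude delta = 1 - alpha
k = int(alpha * n)

# Shift in P(X1, X2): take the alpha-fraction of points with the highest loss.
marginal_idx = np.argsort(losses)[-k:]

# Shift in P(X2 | X1): take the alpha-fraction of highest-loss points within each
# X1 bin, so the X1 marginal of the selected subpopulation stays roughly fixed.
n_bins = 20
edges = np.quantile(X[:, 0], np.linspace(0, 1, n_bins + 1))
bin_id = np.clip(np.digitize(X[:, 0], edges[1:-1]), 0, n_bins - 1)
conditional_idx = np.concatenate([
    idx[np.argsort(losses[idx])[-max(1, int(alpha * len(idx))):]]
    for idx in (np.flatnonzero(bin_id == b) for b in range(n_bins))
])

print("worst-case loss under marginal shift:   ", losses[marginal_idx].mean())
print("worst-case loss under conditional shift:", losses[conditional_idx].mean())
```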

(a) Data distribution
(b) Shift in $P(X_1, X_2)$
(c) Shift in $P(X_2 \mid X_1)$
Figure 2: (a) Samples from the data distribution described in Section 2.2 along with the decision boundary for a linear classifier fit to this data. (b) The worst-case distribution under a shift in the full feature distribution. (c) The worst-case distribution under a shift in the conditional distribution of one feature given the other. In each figure, the marginal distributions of the two features are shown on the top and right. The shift in (c) keeps the marginal of the conditioning feature roughly the same as in (a).

2.3 Estimating the Worst-Case Risk

Having defined the worst-case risk in Equation 4, we now turn to the problem of how to actually estimate this risk. While previous work on distributionally robust optimization has considered how to train models that minimize (upper bounds on) worst-case risk (e.g., Duchi and Namkoong (2018)), to the best of our knowledge the problem of accurately estimating the worst-case risk itself has not received much attention. Thus, in this section we derive a consistent estimator for $R_{\max}(\delta)$ to address this problem. In Section 3.1, we will show that this estimator has favorable bias and variance properties, even in settings with high-dimensional or continuous features.

Our estimator relies heavily on the dual formulation for Equation 4. Defining $\mu(W, Z)$ to be the conditional expected loss as before, and following Duchi et al. (2020) and Duchi and Namkoong (2018), the dual is given by

$R_{\max}(\delta) = \mathbb{E}_{P}\!\left[\frac{1}{\alpha}\big[\mu(W, Z) - \eta_0(Z)\big]_+ + \eta_0(Z)\right],$   (6)

where $[x]_+ = \max(x, 0)$ is the ramp function and the function $\eta_0$ is given by

$\eta_0 = \arg\min_{\eta} \; \mathbb{E}_{P}\!\left[\frac{1}{\alpha}\big[\mu(W, Z) - \eta(Z)\big]_+ + \eta(Z)\right].$   (7)

Note that the objective for $\eta$ is equivalent to the mean absolute deviation objective used in quantile regression. Indeed, $\eta_0(Z)$ is the conditional $(1-\alpha)$-quantile of the conditional expected loss $\mu(W, Z)$ given $Z$. Thus, estimating $\eta_0$ given $\mu$ is exactly quantile regression.2 Estimating $R_{\max}(\delta)$ requires first estimating the two (potentially infinite-dimensional) nuisance parameters $\mu$ and $\eta_0$. The dual formulation for $R_{\max}(\delta)$ suggests a potential naive estimation procedure: (1) estimate $\mu$, (2) estimate $\eta_0$ by plugging $\hat\mu$ into Equation 7, and (3) estimate $R_{\max}(\delta)$ by plugging $\hat\mu$ and $\hat\eta$ into Equation 6. When $W$ and $Z$ are low-cardinality discrete variables and $\eta$ is a low-dimensional vector, this simple procedure is $\sqrt{n}$-consistent, as shown in Duchi et al. (2020) for the case of marginal shifts. In most practical scenarios, in which $W$ or $Z$ will be continuous or high-dimensional, we will instead want to use flexible machine learning (ML)-based estimators (e.g., random forests or deep neural nets) for the nuisance parameters. However, Chernozhukov et al. (2018) showed how, due to the slow convergence rates and common practice of regularization in flexible ML methods, their use in plug-in procedures like this can lead to substantial bias and poor overall convergence rates.
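For concreteness, the sketch below implements steps (1) and (2) of the plugin procedure with off-the-shelf gradient boosting; the learner choices and the assumption that $W$ and $Z$ arrive as numeric arrays are illustrative, not prescribed by the paper.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

def fit_nuisances(loss, W, Z, alpha):
    """Estimate mu(W, Z) = E[loss | W, Z] and eta(Z), the conditional
    (1 - alpha)-quantile of mu(W, Z) given Z (Eq. 7), from numeric arrays."""
    W = np.asarray(W).reshape(len(loss), -1)
    Z = np.asarray(Z).reshape(len(loss), -1)
    WZ = np.hstack([W, Z])
    # Step (1): regress the per-example loss on (W, Z).
    mu_hat = GradientBoostingRegressor().fit(WZ, loss)
    # Step (2): quantile regression of the fitted values on Z at level 1 - alpha.
    eta_hat = GradientBoostingRegressor(loss="quantile", alpha=1 - alpha)
    eta_hat.fit(Z, mu_hat.predict(WZ))
    return mu_hat, eta_hat
```

The debiased estimator described next reuses exactly these two fits, but combines them with the observed losses rather than plugging them directly into Equation 6.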

Input: Model $f$, dataset $\mathcal{D} = \{(X_i, Y_i)\}_{i=1}^{n}$, subpopulation size $\alpha = 1 - \delta$, and cross-validation folds $I_1, \dots, I_K$
for $k = 1, \dots, K$ do
        Estimate $\hat\mu_k(W, Z) \approx \mathbb{E}[\ell(f(X), Y) \mid W, Z]$ using data in $\mathcal{D} \setminus I_k$
        Estimate $\hat\eta_k$ according to Eq. 7 using $\hat\mu_k$ and data in $\mathcal{D} \setminus I_k$
        for $i \in I_k$ do
               Let $\ell_i = \ell(f(X_i), Y_i)$
               Let $\hat h_i = \mathbb{1}\{\hat\mu_k(W_i, Z_i) \ge \hat\eta_k(Z_i)\}$
               Let $\hat\psi_i = \frac{\hat h_i}{\alpha}\,\big(\ell_i - \hat\eta_k(Z_i)\big) + \hat\eta_k(Z_i)$
        end for
end for
Let $\hat R_{\max}(\delta) = \frac{1}{n}\sum_{i=1}^{n} \hat\psi_i$
Result: $\hat R_{\max}(\delta)$
Algorithm 1 Estimator for $R_{\max}(\delta)$

To avoid these issues, following Chernozhukov et al. (2018) and Jeong and Namkoong (2020), we propose a so-called “debiased” machine learning (DML) estimator for $R_{\max}(\delta)$. Under conditions we will state in Section 3.1, this estimator is able to maintain $\sqrt{n}$-consistency without requiring $\sqrt{n}$ convergence rates for the ML estimators used for $\mu$ and $\eta$. The procedure, detailed in Algorithm 1, splits the data into $K$ folds, fits estimates $\hat\mu_k$ and $\hat\eta_k$ for each fold using ML, and then combines these estimates in a way that adjusts for the slow convergence rates of the estimators for $\mu$ and $\eta$. Within the algorithm, $\hat h_i$ implements the indicator function that selects the worst-case subpopulation. Thus, given estimates $\hat\mu_k$ and $\hat\eta_k$ computed by ML methods, we estimate $R_{\max}(\delta)$, the worst-case risk under a shift of magnitude $\delta$, as

$\hat R_{\max}(\delta) = \frac{1}{n}\sum_{k=1}^{K}\sum_{i \in I_k}\left[\frac{\hat h_i}{\alpha}\Big(\ell\big(f(X_i), Y_i\big) - \hat\eta_k(Z_i)\Big) + \hat\eta_k(Z_i)\right], \qquad \hat h_i = \mathbb{1}\{\hat\mu_k(W_i, Z_i) \ge \hat\eta_k(Z_i)\}.$   (8)
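A minimal sketch of the cross-fitting logic in Algorithm 1 is given below, assuming numeric arrays, gradient-boosting nuisance estimators, and the $\alpha = 1 - \delta$ mapping used above; it is an illustration rather than the authors’ implementation.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import KFold

def worst_case_risk(loss, W, Z, delta, n_folds=5, seed=0):
    """Cross-fitted ("debiased") estimate of the worst-case risk (Eq. 8).

    loss : per-example losses l(f(X_i), Y_i); W, Z : numeric arrays.
    Returns the point estimate and the per-example score values.
    """
    alpha = 1.0 - delta
    loss = np.asarray(loss, dtype=float)
    W = np.asarray(W).reshape(len(loss), -1)
    Z = np.asarray(Z).reshape(len(loss), -1)
    WZ = np.hstack([W, Z])
    psi = np.empty(len(loss), dtype=float)
    folds = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    for train_idx, eval_idx in folds.split(loss):
        # Fit both nuisances on the complement of the evaluation fold.
        mu_hat = GradientBoostingRegressor().fit(WZ[train_idx], loss[train_idx])
        eta_hat = GradientBoostingRegressor(loss="quantile", alpha=1 - alpha)
        eta_hat.fit(Z[train_idx], mu_hat.predict(WZ[train_idx]))
        # Evaluate the debiased score on the held-out fold.
        mu_k = mu_hat.predict(WZ[eval_idx])
        eta_k = eta_hat.predict(Z[eval_idx])
        h_k = (mu_k >= eta_k).astype(float)   # worst-case subpopulation indicator
        psi[eval_idx] = h_k / alpha * (loss[eval_idx] - eta_k) + eta_k
    return psi.mean(), psi
```

Plugging $\hat\mu_k$ and $\hat\eta_k$ directly into Equation 6 would give the naive plugin estimate; the score above differs in that the observed loss, rather than $\hat\mu_k$, enters inside the selected subpopulation.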

In the following section, we establish the correctness of this estimator by proving its convergence rate and central limit behavior. Then we will apply the estimator to evaluate the robustness of models on a real clinical diagnosis problem.

3 Results

In this section, we present four main results: First, we prove that the proposed estimation method has properties that allow it to reliably estimate the worst-case expected loss under distributional shift (Section 3.1). Second, in the context of a practical domain, we demonstrate that our method can be used to determine new settings in which prediction models will still be safe to use (Section 3.2.2). Third, we show that worst-case performance estimated by our method matches the performance in an actual new environment (Section 3.2.3). Finally, we highlight the need for a flexible framework for specifying shifts by empirically demonstrating that evaluating a model under different types of shifts can lead to substantively different conclusions (Section 3.2.4). The proposed method fills the gap created by the lack of tools for specifying and evaluating robustness to the many types of shifts that we can encounter in practice.

3.1 Theoretical Results

In order to reliably use our method to evaluate model robustness and safety, it is important that our estimator (i) converges to the true worst-case loss (consistency) and (ii) has low levels of bias and variance. We show this by proving that, despite our use of regularized ML methods, our estimator converges to the true worst-case loss at the same rate as if we had known the true nuisance parameter values ($\sqrt{n}$-consistency). Further, we show that the estimator is asymptotically normal, allowing for easy construction of valid confidence intervals. This represents the first such estimator for worst-case expected loss under distributional shift in non-trivial settings.

To guarantee $\sqrt{n}$-consistency and asymptotic normality, we make the following assumptions, where $\mu_0$ denotes the true conditional expected loss, $\eta_0$ is the true conditional quantile function for $\mu_0$ (i.e., the minimizer in Eq. 7), and, for parsimony, we suppress the arguments of $\ell$, $\mu$, and $\eta$ where they are clear from context.

Assumption 1.

Let $\{\epsilon_n\}$ and $\{\Delta_n\}$ be sequences of positive constants approaching $0$, let $c$, $C$, and $q$ be fixed strictly positive constants such that $q > 2$, and let $K \ge 2$ be a fixed integer. Also, let $(\hat\mu_k, \hat\eta_k)$ be the estimates of $(\mu_0, \eta_0)$ estimated on the $k$’th cross-validation fold. Then, we assume that (a) $\|\ell\|_{P,q} \le C$, and with probability no less than $1 - \Delta_n$: (b) $\|\hat\mu_k - \mu_0\|_{P,\infty} \le \epsilon_n$, (c) $\|\hat\eta_k - \eta_0\|_{P,\infty} \le \epsilon_n$, (d) $\|\hat\mu_k - \mu_0\|_{P,2} \le \epsilon_n n^{-1/4}$, (e) $\|\hat\eta_k - \eta_0\|_{P,2} \le \epsilon_n n^{-1/4}$, and (f) for almost every $Z$, the conditional distribution of $\mu_0(W, Z)$ given $Z$ has a positive density at $\eta_0(Z)$.

Assumptions 1 (a)–(c) place a moment condition on the loss function and assume basic consistency of the nuisance parameter estimators. Assumption 1 (f) is a standard requirement which ensures that the conditional quantiles of $\hat\mu$ converge to the conditional quantiles of $\mu_0$ (Jeong and Namkoong, 2020; Van der Vaart, 2000). Finally, Assumptions 1 (d) and (e) place convergence rate restrictions on the estimators for $\mu$ and $\eta$, which are notably slower than the $\sqrt{n}$ rate we desire. This admits the use of various ML estimators for $\mu$ and $\eta$, such as ReLU neural networks.3 We can now guarantee $\sqrt{n}$-consistency and central limit properties of $\hat R_{\max}(\delta)$ using a version of Theorem 3.1 from Chernozhukov et al. (2018):

Theorem 1.

Under Assumption 1, let $\{\rho_n\}$ be a sequence of positive constants converging to zero such that $\rho_n \ge n^{-1/2}$ for all $n$. Then we have that $\hat R_{\max}(\delta)$ concentrates in a $\rho_n$-neighborhood of $R_{\max}(\delta)$ and is approximately linear and centered Gaussian:

$\sigma^{-1}\sqrt{n}\,\big(\hat R_{\max}(\delta) - R_{\max}(\delta)\big) \;=\; \frac{1}{\sqrt{n}}\sum_{i=1}^{n} \bar\psi(O_i) + O_P(\rho_n) \;\rightsquigarrow\; N(0, 1),$   (9)

where $O_i = (X_i, Y_i)$, $\bar\psi(O) = \sigma^{-1}\big(\tfrac{h_0(W, Z)}{\alpha}\,(\ell(f(X), Y) - \eta_0(Z)) + \eta_0(Z) - R_{\max}(\delta)\big)$ with $h_0(W, Z) = \mathbb{1}\{\mu_0(W, Z) \ge \eta_0(Z)\}$,

and $\sigma^2 = \mathbb{E}\big[\big(\tfrac{h_0(W, Z)}{\alpha}\,(\ell(f(X), Y) - \eta_0(Z)) + \eta_0(Z) - R_{\max}(\delta)\big)^2\big]$.

An important consequence of this result is that we can further estimate the variance of $\hat R_{\max}(\delta)$ as

$\hat\sigma^2 = \frac{1}{n}\sum_{k=1}^{K}\sum_{i \in I_k}\Big(\frac{\hat h_i}{\alpha}\big(\ell(f(X_i), Y_i) - \hat\eta_k(Z_i)\big) + \hat\eta_k(Z_i) - \hat R_{\max}(\delta)\Big)^2$   (10)

and this estimate can be used to construct valid confidence intervals as $\hat R_{\max}(\delta) \pm \Phi^{-1}(1 - a/2)\,\hat\sigma/\sqrt{n}$ (Chernozhukov et al., 2018). Note that the variance estimate grows as the subpopulation size $\alpha = 1 - \delta$ shrinks, highlighting that a large dataset may be needed to estimate $R_{\max}(\delta)$ for large shifts. With a reliable estimator in hand, we now demonstrate its utility on a real clinical prediction problem.
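Continuing the sketch above, a normal-approximation interval can be formed from the per-example score values returned by the `worst_case_risk` helper defined after Equation 8; this is illustrative only.

```python
import numpy as np
from scipy.stats import norm

def confidence_interval(psi, level=0.95):
    """Two-sided normal-approximation interval from per-example scores psi."""
    theta_hat = float(np.mean(psi))
    sigma_hat = float(np.std(psi, ddof=1))   # sample analogue of Eq. 10
    half_width = norm.ppf(0.5 + level / 2) * sigma_hat / np.sqrt(len(psi))
    return theta_hat - half_width, theta_hat + half_width
```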

3.2 Experimental Results

In the context of a practical domain, we now demonstrate that: (a) the proposed method can be used to evaluate the performance of models under meaningful distribution shifts and (b) without the ability to evaluate multiple types of shifts, we may reach incorrect conclusions about model robustness. Suppose we are third-party reviewers, such as the FDA, who wish to evaluate the robustness of a machine learning-based clinical diagnostic model to changes in clinician lab test ordering patterns. Our goal is to test the ability of the model to generalize to new hospitals, which are likely to have different test ordering patterns.

Dataset

We demonstrate our approach on the task of evaluating machine learning models for diagnosing sepsis, a life-threatening response to infection. We follow the setup of Giannini et al. (2019), who developed a clinically validated sepsis diagnosis algorithm. Our dataset contains electronic health record data collected over four years at Hospital A. The dataset consists of 278,947 emergency department patient encounters. The prevalence of the target disease, sepsis, is . 17 features pertaining to vital signs, lab tests, and demographics were extracted. A full characterization of the dataset and features is provided in the supplement. We evaluate the robustness of the models to changes in test ordering patterns using a held-out sample of 10,000 patients which serves as the evaluation dataset.

We consider two models for predicting sepsis. The first model (classical) was trained using classical supervised learning methods, while the second (robust) was trained using the “surgery estimator” (Subbaswamy et al., 2019). While both models are random forest classifiers and were trained using the same data, the robust model was trained with the goal of being stable to shifts in test ordering patterns. Because our focus is on evaluating models rather than training them, we refer readers to the supplemental material for details about the training procedures.

Evaluating Robustness

Figure 3: (a) Accuracy of the two models as the magnitude of shift increases. Performance below the horizontal dashed line (accuracy of a naive reference classifier) is deemed unsafe. The vertical dashed line denotes the end of the plausible shift region. The robust model remains safe for all shifts in the plausible shift region, whereas the classical model does not. (b) The maximum possible test ordering rate (green) under a shift of magnitude $\delta$. The original ($\delta = 0$) test ordering rate was .

To evaluate robustness to changes in test ordering patterns, we estimate how the classification accuracy of the two models changes as we vary $\delta$, the magnitude of the shift. When $\delta = 0$ there is no shift, while as $\delta$ approaches 1 the shifted distribution can become increasingly different. The test ordering patterns can be represented by the conditional distribution $P(W \mid Z)$, where $W$ is the test-order indicator and $Z$ contains the patient covariates. To make the results interpretable in the context of our application, we first map $\delta$ to a meaningful value: the maximum possible fraction of patients that could receive a lab test under a shift of magnitude $\delta$. Figure 3(b) shows this mapping, which allows us to reason about what shifts are plausible in our context. For example, by consulting with domain experts we determine that the most aggressive plausible test ordering patterns would result in at most 50% of patients receiving a test, corresponding to a maximum plausible shift magnitude (vertical dashed line in Fig 3).
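One way such a mapping could be computed is sketched below: under the CVaR uncertainty set, the shifted conditional probability of a test order can be inflated by at most a factor of $1/\alpha$ (and capped at one), so the largest achievable ordering rate is the average of $\min\{1,\,P(W = 1 \mid Z)/\alpha\}$ over patients. The per-patient probabilities are assumed to come from some classifier fit on the evaluation data; whether this matches the exact construction of Figure 3(b) is not claimed here.

```python
import numpy as np

def max_ordering_rate(p_order_given_z, delta):
    """Largest test-ordering rate attainable under a shift of magnitude delta,
    given estimates of P(W = 1 | Z_i) for each patient i in the evaluation data."""
    alpha = 1.0 - delta
    return float(np.mean(np.minimum(1.0, p_order_given_z / alpha)))

# Hypothetical per-patient ordering probabilities, used to trace out the curve.
p_hat = np.random.default_rng(0).uniform(0.05, 0.45, size=10_000)
for delta in (0.0, 0.25, 0.5, 0.75):
    print(delta, round(max_ordering_rate(p_hat, delta), 3))
```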

We can now evaluate how the two models perform under plausible shifts. Figure 3(a) shows the worst-case accuracy of the classical (blue) and robust (orange) models for different shift magnitudes, with the maximum plausible shift magnitude denoted by a vertical dashed line. Note that for almost all shift magnitudes, the robust model outperforms the classical model, as we would expect. The red region of the plot represents accuracy levels that are deemed to be unacceptable for safe use. This performance level will be domain specific, and in this case, is set to the performance of a naive, majority class predictor.4 Importantly, for all plausible shift magnitudes the robust model remains above our safety threshold, whereas the classical model does not. In fact, the classical model dips below the safety threshold within the plausible shift region. Thus, we would conclude that as long as the proportion of patients with tests ordered remains below 50%, the robust model is safe to use while the classical model may not be. This result demonstrates that the proposed method can be used to evaluate performance under varied shifts without requiring additional datasets.
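Putting the pieces together, a reviewer could trace out a Figure 3(a)-style curve by sweeping the shift magnitude and comparing the worst-case accuracy against a safety threshold. The sketch below reuses the `worst_case_risk` helper defined after Equation 8, uses 0/1 losses so that worst-case accuracy is one minus the worst-case risk, and runs on invented data.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
Z = rng.normal(size=(n, 3))                    # stand-in patient characteristics
W = rng.binomial(1, 0.37, size=n)              # test-order indicator
errors = rng.binomial(1, 0.10 + 0.15 * W).astype(float)   # synthetic 0/1 losses

safety_threshold = 0.80                        # e.g., naive majority-class accuracy
for delta in (0.1, 0.3, 0.5, 0.7):
    risk, _ = worst_case_risk(errors, W, Z, delta)   # helper from Section 2.3 sketch
    accuracy = 1.0 - risk
    flag = "safe" if accuracy >= safety_threshold else "below threshold"
    print(f"delta={delta:.1f}  worst-case accuracy={accuracy:.3f}  ({flag})")
```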

Validating Conclusions: When Can We Trust the Estimates?

We now validate that the estimates provided by our method are correct by comparing the estimated worst-case performance under a particular shift to the performance observed in a new environment exhibiting such a shift. To do so, we used an additional dataset containing the same variables collected from a different hospital: Hospital B.5 This new hospital has a similar patient population to Hospital A, but different test ordering rates, and thus matches the shift in test ordering patterns considered in the previous section.

Figure 4: Age distributions at the two hospitals are very similar.

Recall that changes in test ordering patterns correspond to changes in the conditional distribution $P(W \mid Z)$ which hold the marginal distribution $P(Z)$ fixed. We first provide evidence that this marginal distribution is similar between the two hospitals using univariate comparisons. The prevalence of the disease is at Hospital A vs at Hospital B. Turning to demographics, the population at Hospital A is female while at Hospital B it is female. Finally, Fig 4 shows that Kernel Density Estimates of the age distributions at the two hospitals are very similar.
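The univariate checks above are straightforward to reproduce; a sketch is given below with hypothetical column names (`sepsis`, `female`, `age`) and a simple density-overlap summary for the age distributions (the paper itself only reports the KDE plot in Fig 4).

```python
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde

def compare_fixed_marginals(df_a, df_b, age_grid=np.linspace(0, 100, 200)):
    """Compare simple summaries of the held-fixed variables Z across two sites."""
    summary = pd.DataFrame(
        {"hospital_A": [df_a["sepsis"].mean(), df_a["female"].mean()],
         "hospital_B": [df_b["sepsis"].mean(), df_b["female"].mean()]},
        index=["sepsis prevalence", "fraction female"],
    )
    # Kernel density estimates of the age distributions (as in Fig 4), summarized
    # by their overlap: 1.0 means the two densities coincide on the grid.
    kde_a = gaussian_kde(df_a["age"])(age_grid)
    kde_b = gaussian_kde(df_b["age"])(age_grid)
    overlap = float(np.minimum(kde_a, kde_b).sum() * (age_grid[1] - age_grid[0]))
    return summary, overlap
```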

On the other hand, the two hospitals differ substantially in the rate of test orders ( ordering rate at Hospital A vs ordering rate at Hospital B). By mapping to the maximum proportion of patients with labs ordered, as in Section 3.2.2, we find that a shift of magnitude is necessary to produce a test ordering rate (see the dashed line in Fig 5b). Further, we had determined that increases in test ordering rate resulted in lower accuracy for the classical model. Thus, we hypothesized that the accuracy of the classical model at Hospital B should be greater than or equal to the worst-case accuracy for .

Figure 5: (a) Estimated worst-case performance of the classical model under shifts of magnitude $\delta$ (blue curve). The orange dot is the actual performance of the classical model when applied to Hospital B. The shaded blue region denotes a confidence interval. (b) The maximum possible test ordering rate (green) under a shift of magnitude $\delta$. The original (Hospital A) test ordering rate was while at Hospital B it is .

Fig 5a plots the estimated worst-case performance of the classical model under a shift of a given magnitude. We find that the accuracy at Hospital B (orange dot) is, indeed, nearly equal to the estimated worst-case accuracy (horizontal dashed line). This supports our hypothesis and validates that the estimated worst-case risk curves allow us to make meaningful conclusions about how performance changes under worst-case shifts. The confidence intervals (blue shaded region) show that the variance of the estimated risk increases with the magnitude of the shift $\delta$. Thus, estimated performance under large shifts should be considered with caution.

Different Shifts, Different Conclusions

Figure 6: Accuracy of the classical and robust models under a marginal shift.

While some pertinent shifts, such as shifts in demographics, are readily expressible as marginal shifts, others, like the shifts in test ordering considered in Section 3.2.2, must be expressed using the more general conditional formulation we introduced in Section 2. Correspondingly, we need tools that can evaluate performance against the possible shifts of interest.

To demonstrate, we show that we reach different safety conclusions when we evaluate the diagnosis models for robustness to changes in the marginal distribution of the patient covariates. Figure 6 shows the estimated worst-case accuracy for the classical (blue) and robust (orange) models under this type of marginal shift. We can see that both models dip below the safety threshold under even minor shifts, well below the plausible magnitude threshold we previously established. Further, under this type of shift, the robust and classical models have nearly identical worst-case accuracy. Thus, had we only considered this marginal shift, we would have determined that both models were unsafe to use. By considering both types of shifts, we can instead conclude that the robust model is safe under changes in test ordering practice, but not under broad changes to the patient population. Thus, the proposed method enables evaluation under types of shifts that were not previously possible.

4 Related Work

We now overview various threads of work on the problem of dataset shift, in which the deployment environment differs from the training environment.

Adapting models to new environments: One of the most common dataset shift paradigms assumes that the deployment environment is known and that we have limited access to data from the deployment environment (Quiñonero-Candela et al., 2009). Many works consider the problem of learning a model using labeled data from the training environment and unlabeled data from the deployment environment, using the unlabeled data to adjust for shifts in through reweighting (e.g., Shimodaira (2000); Huang et al. (2007)) or extracting invariant feature representations (e.g., Gong et al. (2016); Ganin et al. (2016)). Rai et al. (2010) assume that we have limited capacity to query the deployment environment and use active learning techniques to adapt a model from the training environment to the deployment environment. While these types of adaptations should absolutely be conducted when possible, our goal in this work is to evaluate how a model will perform in potential future environments from which we do not currently have samples.

Learning robust models: Another large body of research attempts to proactively improve robustness to dataset shift by learning models that are robust to changes or uncertainty in the data distribution. This work falls broadly under the umbrella of distributionally robust optimization (DRO) which, in turn, comes from a large body of work on formulating optimization problems affected by uncertainty (e.g., Ben-Tal et al. (2013); Duchi et al. (2016); Bertsimas et al. (2018)). As in our work, DRO assumes that the true population distribution is in an uncertainty “ball” around the empirical data distribution and optimizes with respect to the worst-case such distribution. In some cases, the uncertainty set is designed to reflect sampling variability and thus the radius of the ball linearly decreases with the number of samples, but no distributional shift is assumed (Namkoong and Duchi, 2016, 2017; Lei, 2020). Work on DRO explores a variety of ways to define the uncertainty ball/set of distributions. Some have explored balls defined by the so-called f-divergences, which include as special cases KL divergence, $\chi^2$ divergence, and CVaR (used in this work) (Lam, 2016; Namkoong and Duchi, 2016, 2017; Duchi and Namkoong, 2018). Still others consider uncertainty sets defined by Reproducing Kernel Hilbert Spaces (RKHS) via Maximum Mean Discrepancy (MMD) (Gretton et al., 2012) and Wasserstein distances (Fournier and Guillin, 2015; Abadeh et al., 2015; Sinha et al., 2017; Esfahani and Kuhn, 2018; Lei, 2020). Unlike approaches using f-divergences, these approaches can allow for distributions with differing support, but they are computationally challenging, often requiring restrictions on the loss or transportation cost functions. Future work may consider extensions of the proposed evaluation framework to MMD and Wasserstein-based uncertainty sets.

A related line of work defines uncertainty sets of environments using causal interventions on the data generating process which allow for arbitrary strength shifts and do not have to be centered around the training distribution (Meinshausen, 2018; Bühlmann, 2020). These methods aim to learn a model with stable or robust performance across the uncertainty set of environments, assuming either that there is access to datasets collected from multiple environments (Rojas-Carulla et al., 2018; Arjovsky et al., 2019) or that a causal graph of the data generating process is known (Subbaswamy and Saria, 2018; Subbaswamy et al., 2019).

Evaluating robustness: Relatively few works have focused on evaluating the robustness of a model to distributional shift. Santurkar et al. (2020) proposed an algorithm for generating evaluation benchmarks for subpopulation shifts (new subpopulations that were unseen in the training data) by combining datasets with hierarchical class labels. Oakden-Rayner et al. (2020) considered broad, sometimes manual, strategies to evaluate the performance of a model in subpopulations. Both of these start by constructing subpopulations that have semantic meaning and evaluate the model on each of the subpopulations. Thus, discovering a subpopulation with poor performance is either serendipitous or is guided by domain knowledge. In this work, we take a data-driven approach, starting with a worst-case subpopulation and then exploring its properties.

Estimating optimal treatment subpopulations: Finally, our work is methodologically similar to that of Jeong and Namkoong (2020) and VanderWeele et al. (2019) which sought to estimate the causal effect of a treatment in the worst- and best-case subpopulations, respectively. Whereas, in our work, we seek to find a subpopulation with the highest expected conditional loss, they seek to find a subpopulation with the highest (lowest) conditional average treatment effect, which can be formulated as the best- (worst-) case average treatment effect under a marginal shift. A potential extension of this work is to consider optimal treatment subgroups defined by other types of shifts.

5 Conclusion

As machine learning systems are adopted in high impact industries such as healthcare, transportation, and finance, a growing challenge is to proactively evaluate the safety of these systems to avoid the high costs of failure. To this end, we proposed a framework and estimation method for proactively evaluating the robustness of trained machine learning models to shifts in population or setting without requiring the collection of new datasets (an often costly and time-consuming effort). Instead, the proposed framework generalizes previous work on distributional robustness, which focused on shifts in joint or marginal distributions, to include shifts in conditional distributions. Adding this capability allows us to evaluate robustness to fine-grained shifts in which particular populations or settings are held fixed under the shift. As demonstrated in the experiments, this now enables us to test robustness to realistic shifts which correspond to policies that can vary across sites, datasets, or over time. We envision that procedures like the proposed method should become standard practice for assessing the safety of the deployment of models in new settings.

Appendix A Causally Interpreting Distribution Shifts

Under certain conditions, shifts in a conditional distribution $P(W \mid Z)$ have an important interpretation as evaluating robustness to causal policy interventions or process changes (Pearl, 2009). That is, the effects of the shift correspond to how the distribution would change under an intervention that changes the way $W$ is generated. Formally, we have the following:

Proposition 2.

Suppose the data were generated by a structural causal model (SCM) with no unobserved confounders, respecting a causal directed acyclic graph (DAG) $\mathcal{G}$. Then, for a single variable $W$ and set $Z = \mathrm{nd}_{\mathcal{G}}(W)$ (the non-descendants of $W$ in $\mathcal{G}$), a shift in $P(W \mid Z)$ can be expressed as a policy intervention on the mechanism generating $W$ which changes $P(W \mid \mathrm{pa}_{\mathcal{G}}(W))$.

Proof.

Within the SCM, we have that $W$ is generated by the structural equation $W = f_W(\mathrm{pa}_{\mathcal{G}}(W), \epsilon_W)$ (where $\epsilon_W$ is a $W$-specific exogenous noise random variable). A policy intervention on $W$ replaces this structural equation with a new function $\tilde{f}_W(\mathrm{pa}_{\mathcal{G}}(W), \epsilon_W)$, which has the effect of changing $P(W \mid \mathrm{pa}_{\mathcal{G}}(W))$ to some new distribution $\tilde{P}(W \mid \mathrm{pa}_{\mathcal{G}}(W))$. By the local Markov property, we have that $P(W \mid Z) = P(W \mid \mathrm{pa}_{\mathcal{G}}(W))$. Thus, a shift from $P(W \mid Z)$ to $Q(W \mid Z)$ can be expressed as a policy intervention from $f_W$ to $\tilde{f}_W$. ∎

This result means that in order to causally interpret distribution shifts, we need to adjust for (i.e., put into $Z$) variables that are relevant to the mechanism that generates $W$. Fortunately, we can place additional variables into $Z$ so long as they precede $W$ in a causal or topological order.
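This ordering condition is easy to check mechanically when a causal DAG is available; the sketch below uses networkx and an invented toy graph for the sepsis example.

```python
import networkx as nx

def is_valid_conditioning_set(dag, w, z):
    """Check the condition of Proposition 2: every variable in Z must be a
    non-descendant of W in the causal DAG (i.e., Z causally precedes W)."""
    forbidden = nx.descendants(dag, w) | {w}
    return forbidden.isdisjoint(z)

# Hypothetical DAG: demographics influence the test order, which influences the
# test result, which influences the recorded diagnosis.
g = nx.DiGraph([("age", "lab_ordered"), ("history", "lab_ordered"),
                ("lab_ordered", "lab_result"), ("lab_result", "sepsis")])
print(is_valid_conditioning_set(g, "lab_ordered", ["age", "history"]))  # True
print(is_valid_conditioning_set(g, "lab_ordered", ["lab_result"]))      # False
```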

This result can be extended to the case in which the SCM contains unobserved variables. This is of practical importance because often we do not have all relevant variables recorded in the dataset (i.e., there may be unobserved confounders). In these cases, rather than a DAG, the SCM takes the graphical form of a causal acyclic directed mixed graph (ADMG) (Richardson et al., 2017). ADMGs have directed edges which represent direct causal influence (just like in DAGs), but also have bidirected edges which represent the existence of an unobserved confounder between the two endpoint variables.

We require one technical definition: we will define the Markov blanket of $W$ in an ADMG $\mathcal{G}$ to be $\mathrm{mb}_{\mathcal{G}}(W) = \big(\mathrm{dis}_{\mathcal{G}}(W) \cup \mathrm{pa}_{\mathcal{G}}(\mathrm{dis}_{\mathcal{G}}(W))\big) \setminus \{W\}$, where $\mathrm{dis}_{\mathcal{G}}(W)$ refers to the district of a variable and is the set of variables reachable through entirely bidirected paths.

Proposition 3.

Suppose the data were generated by a structural causal model (SCM), respecting a causal ADMG $\mathcal{G}$. Then, for a single variable $W$ and set $Z = \mathrm{mb}_{\mathcal{G}}(W)$, a shift in $P(W \mid Z)$ can be expressed as a policy intervention on the mechanism generating $W$.

Proof.

The proof follows just as before, noting that the local Markov property in ADMGs states that a variable is independent of all variables preceding it in a topological order, conditional on its Markov blanket (Richardson et al., 2017, Section 2.8.2). ∎

Appendix B Proof of Theorem 1

In this section, we provide a proof of Theorem 1 in the main paper. This proof draws heavily on results from Chernozhukov et al. (2018) and Jeong and Namkoong (2020). For notational simplicity and consistency between our work and theirs, let $\theta_0 = R_{\max}(\delta)$ denote the target parameter. Algorithm 1 in the main paper is an instance of the DML2 algorithm from Chernozhukov et al. (2018) where the score function is given by

$\psi(O; \theta, \nu) = m(O; \nu) - \theta,$   (11)

where

$m(O; \nu) = \frac{h_\nu(W, Z)}{\alpha}\big(\ell(f(X), Y) - \eta(Z)\big) + \eta(Z),$   (12)

and where $h_\nu(W, Z) = \mathbb{1}\{\mu(W, Z) \ge \eta(Z)\}$, $\nu = (\mu, \eta)$ denotes the nuisance parameters, and $O = (X, Y)$. In this proof, we will show that Assumptions 3.1 and 3.2 of Chernozhukov et al. (2018) are nearly satisfied and will fill in the gaps where they are not. We restate these assumptions here with some of the notation changed to match the notation used in this paper. First, some definitions: Let $c_0$, $c_1$, $s$, and $q > 2$ be some finite constants such that $c_0 \le c_1$, and let $\{\epsilon_n\}$ and $\{\Delta_n\}$ be some positive constants converging to zero such that $\epsilon_n \ge n^{-1/2}$. Also, let $K \ge 2$ be some fixed integer, and let $\{\mathcal{P}_n\}$ be some sequence of sets of probability distributions of $O$ on its sample space. Let $\mathcal{T}$ be a convex subset of some normed vector space representing the set of possible nuisance parameters (i.e., $\nu \in \mathcal{T}$). Finally, let $a \lesssim b$ denote that there exists a constant $C$ such that $a \le C b$.

Assumption 2.

(Assumption 3.1 from Chernozhukov et al. (2018)) For all and , the following conditions hold. (a) The true parameter value obeys . (b) The score can be written as . (c) The map is twice continuously Gateaux-differentiable on . (d) The score obeys Neyman orthogonality. (e) The singular values of the matrix are between and .

Assumption 3.

(Assumption 3.2 from Chernozhukov et al. (2018)) For all and , the following conditions hold. (a) Given a random subset of of size , the nuisance parameter estimator belongs to the realization set with probability at least , where contains and is constrained by the next conditions. (b) The following moment conditions hold:

(c) The following conditions on the statistical rates , , and hold:

(d) The variance of the score is non-degenerate: All eigenvalues of the matrix are bounded from below by .

Here, we will show that all of these conditions are satisfied except for Assumption 2 (c) and, by extension, the bound on in Assumption 3. These two conditions are used in Chernozhukov et al. (2018) to prove that, for any sequence such that , the following holds for all

(13)

where

(14)

and where $\mathbb{E}_{n,k}$ denotes the empirical expectation w.r.t. the $k$’th cross-validation fold. We will prove this using other means. First, we establish that all other conditions in Assumptions 2 and 3 hold for all . For notational simplicity, we will drop the dependence of , , and on throughout. Additionally, denote by the event that .

Proof of Assumption 2 (a) This holds trivially via the definitions of and .

(15)
(16)

Proof of Assumption 2 (b) This holds trivially with .

Proof of Assumption 2 (d) To show Neyman orthogonality of , we must show that, for , the set of possible nuisance parameter values, and , the Gateaux derivative map exists for all where

and that vanishes for . For notational simplicity, let , with analogous definitions for and . Then, using Danskin’s theorem, exists for and is given by

Finally, we have

The second line follows from the definitions of and . The final line follows from the constraint that for all .

Proof of Assumption 2 (e) This holds trivially since .

Proof of Assumption 3 (a) This holds by construction of and Assumption 1.

Proof of Assumption 3 (b) The bound on holds trivially as . To bound on the event , we begin by decomposing it using the triangle inequality as

(17)
(18)

Since , and by the triangle inequality and Assumption 1, we have

(19)
(20)
(21)
(22)

where the fourth line follows from Assumption 1 (a). Next, we can bound as

(23)
(24)
(25)
(26)

where the fourth line follows from Jensen’s inequality and Assumption 1 (a). Thus, and Assumption 3 (b) holds.

Proof of Assumption 3 (c) The bound on is trivially satisfied. Further,

(27)
(28)
(29)

where the first line follows from the definition of and , the second line follows from the triangle inequality, and the third line follows from Assumption 1. Then, to bound , first observe that where is the Iverson bracket. Then, we have

(30)