On Scalable Inference with Stochastic Gradient Descent

Yixin Fang, Department of Mathematical Sciences, New Jersey Institute of Technology (corresponding author: Cullimore Hall 6th floor, NJIT, Newark, NJ 07102; Email: yixin.fang@njit.edu)
Jinfeng Xu, Department of Statistics and Actuarial Science, Hong Kong University
Lei Yang, Department of Population Health, New York University School of Medicine
Abstract

In many applications involving large datasets or online updating, stochastic gradient descent (SGD) provides a scalable way to compute parameter estimates and has gained increasing popularity due to its numerical convenience and memory efficiency. While the asymptotic properties of SGD-based estimators were established decades ago, statistical inference such as interval estimation remains largely unexplored. Traditional resampling methods such as the bootstrap are not computationally feasible, since they require repeatedly drawing independent samples from the entire dataset. The plug-in method is not applicable when there is no explicit formula for the covariance matrix of the estimator. In this paper, we propose a scalable inferential procedure for stochastic gradient descent which, upon the arrival of each observation, updates the SGD estimate as well as a large number of randomly perturbed SGD estimates. The proposed method is easy to implement in practice. We establish its theoretical properties for a general class of models that includes generalized linear models and quantile regression models as special cases. The finite-sample performance and numerical utility are evaluated by simulation studies and two real data applications.


Keywords: Bootstrap, Interval estimation, Generalized linear models, Large datasets, M-estimators, Quantile regression, Resampling methods, Stochastic gradient descent

1 Introduction

Big datasets arise frequently in clinical, epidemiological, financial and sociological studies. In such applications, classical optimization methods for parameter estimation such as Fisher scoring, the EM algorithm or iteratively reweighted least squares (Hastie et al. 2009, Nelder & Baker 1972) do not scale well and are computationally less attractive. Due to its computational and memory efficiency, stochastic gradient descent (SGD; Robbins & Monro 1951) provides a scalable way for parameter estimation and has recently drawn a great deal of attention. Unlike classical methods, which evaluate an objective function involving the entire dataset and require expensive matrix inversions, SGD computes the gradient of the objective function from only one data point at a time and recursively updates the parameter estimate. This is numerically appealing and particularly useful in online settings such as streaming data, where it may not even be feasible to retain the entire dataset at the same time. Wang et al. (2015) give a review of recent achievements in applying SGD to big data and streaming data.

The asymptotic properties of SGD estimators, such as consistency and asymptotic normality, were established long ago; see, for example, Ruppert (1988) and Polyak & Juditsky (1992). However, statistical inference such as confidence interval estimation for SGD estimators has remained largely unexplored. Traditional interval estimation procedures such as the plug-in procedure and the bootstrap are often numerically difficult for big datasets. The bootstrap repeatedly draws samples from the entire dataset and is thus computationally prohibitive. The plug-in estimator requires an explicit variance-covariance formula and involves expensive matrix inversion. In addition, the bootstrap is not applicable in the online setting, where samples arrive sequentially and it may be neither necessary nor feasible to store the entire dataset. Neither approach provides a scalable way for interval estimation.

As far as we know, Chen et al. (2016) is the only work that considers statistical inference for SGD. Although computationally efficient, their batch-means procedure substantially underestimates the variance of the SGD estimator in finite samples, as shown in their own simulation studies, because of the correlations between the batch means. In addition, the batch sizes are difficult to choose.

In this paper, we propose a perturbation-based resampling procedure to approximate the distribution of an SGD estimator in a general class of models that includes generalized linear models and quantile regression as special cases. Our proposal, justified by asymptotic theory, provides a simple way to estimate the covariance matrix and construct confidence regions. Through numerical experiments, we verify that this procedure gives accurate inference for big datasets.

The rest of the article is organized as follows. In Section 2, we introduce the proposed perturbation-based resampling procedure for constructing confidence regions. In Section 3, we theoretically justify the validity of our proposal for a general class of models. In Section 4, we demonstrate the performance of the proposed procedures in finite samples via simulation studies and two real data applications. Some concluding remarks are given in Section 5 and all the technical proofs are relegated to the Appendix.

2 The proposed resampling procedure

Parameter estimation by optimizing an objective function is often encountered in statistical practice. Consider the general situation where the optimal model parameter $\theta_0$ is defined to be the minimizer of the expected loss function,

$$\theta_0 = \arg\min_{\theta} L(\theta), \qquad L(\theta) = E\{\ell(\theta; Z)\}, \tag{1}$$

where $\ell(\theta; Z)$ is some loss function and $Z$ denotes one single observation. Suppose that the data consist of $n$ independent and identically distributed (i.i.d.) copies of $Z$, denoted by $Z_1, \ldots, Z_n$. Under mild conditions, $\theta_0$ can be consistently estimated by

$$\tilde\theta_n = \arg\min_{\theta} \frac{1}{n}\sum_{i=1}^{n} \ell(\theta; Z_i). \tag{2}$$

However, the minimization problem (2) for big datasets with millions of data points poses numerical challenges for classical methods such as the Newton-Raphson algorithm and iteratively reweighted least squares. Furthermore, for applications such as online data where samples arrive sequentially (e.g., search queries or transactional data), it may be neither necessary nor feasible to store the entire dataset, let alone evaluate the minimand in (2).

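As a stochastic approximation method (Robbins & Monro 1951), stochastic gradient descent provides a scalable way for parameter estimation with large-scale data. Given an initial estimate $\hat\theta_0$, the SGD method recursively updates the estimate upon the arrival of each data point $Z_n$, $n = 1, 2, \ldots$,

$$\hat\theta_n = \hat\theta_{n-1} - \gamma_n \nabla\ell(\hat\theta_{n-1}; Z_n), \tag{3}$$

where $\nabla\ell(\theta; Z) = \partial\ell(\theta; Z)/\partial\theta$ (a subgradient when $\ell$ is not differentiable), and the learning rate is $\gamma_n = \gamma_1 n^{-\alpha}$ with $\gamma_1 > 0$ and $\alpha \in (0.5, 1)$. As suggested by Ruppert (1988) and Polyak & Juditsky (1992), the final SGD estimate is often taken as the averaging estimate,

$$\bar\theta_n = \frac{1}{n}\sum_{i=1}^{n} \hat\theta_i. \tag{4}$$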
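The recursion (3) and the averaging step (4) can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' code: it assumes NumPy, a user-supplied per-observation gradient `grad(theta, z)`, and the schedule $\gamma_n = \gamma_1 n^{-\alpha}$; the name `averaged_sgd` and the defaults `gamma1=0.5`, `alpha=0.6` are illustrative choices. The average (4) is kept as a running mean, so no past iterates need to be stored.

```python
import numpy as np


def averaged_sgd(grad, data, theta_init, gamma1=0.5, alpha=0.6):
    """Run the SGD recursion (3) and return the last iterate and the average (4).

    grad(theta, z) must return the gradient of the loss at a single observation z.
    The learning rate is gamma_n = gamma1 * n**(-alpha) with alpha in (0.5, 1).
    """
    theta = np.array(theta_init, dtype=float)
    theta_bar = np.zeros_like(theta)
    for n, z in enumerate(data, start=1):
        gamma_n = gamma1 * n ** (-alpha)
        theta = theta - gamma_n * grad(theta, z)  # SGD update (3)
        theta_bar += (theta - theta_bar) / n      # running mean of theta_1..theta_n, i.e. (4)
    return theta, theta_bar


if __name__ == "__main__":
    # Toy example: estimate a mean with loss (theta - z)^2 / 2, whose gradient is theta - z.
    rng = np.random.default_rng(0)
    data = rng.normal(loc=2.0, scale=1.0, size=5000)
    _, theta_bar = averaged_sgd(lambda th, z: th - z, data, theta_init=[0.0])
    print(theta_bar)  # should be close to 2.0
```

In the toy usage, the loss $(\theta - z)^2/2$ is minimized in expectation at $E(Z)$, so the averaged estimate should settle near the true mean 2.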

In order to conduct statistical inference with the averaging SGD estimator $\bar\theta_n$, we propose a perturbation resampling procedure which, upon the arrival of each data point, recursively updates the SGD estimate as well as a large number of randomly perturbed SGD estimates. Specifically, let $W_1, \ldots, W_n$ be i.i.d. non-negative random variables with mean and variance both equal to one. In parallel with (3) and (4), with $\hat\theta_0^* = \hat\theta_0$, upon observing data point $Z_n$, we recursively update the randomly perturbed SGD estimates,

$$\hat\theta_n^* = \hat\theta_{n-1}^* - \gamma_n W_n \nabla\ell(\hat\theta_{n-1}^*; Z_n), \tag{5}$$
$$\bar\theta_n^* = \frac{1}{n}\sum_{i=1}^{n} \hat\theta_i^*. \tag{6}$$

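We will show that, conditionally on the data, $\sqrt{n}(\bar\theta_n^* - \bar\theta_n)$ converges in distribution to the same limit as $\sqrt{n}(\bar\theta_n - \theta_0)$. In practice, this result allows us to estimate the distribution of $\sqrt{n}(\bar\theta_n - \theta_0)$ by generating a large number, say $B$, of random samples of $(W_1, \ldots, W_n)$. For each sample $b = 1, \ldots, B$, we obtain $\bar\theta_{b,n}^*$ by sequentially updating the perturbed SGD estimates,

$$\hat\theta_{b,n}^* = \hat\theta_{b,n-1}^* - \gamma_n W_{b,n} \nabla\ell(\hat\theta_{b,n-1}^*; Z_n), \tag{7}$$
$$\bar\theta_{b,n}^* = \frac{1}{n}\sum_{i=1}^{n} \hat\theta_{b,i}^*, \tag{8}$$

and then approximate the sampling distribution of $\bar\theta_n$ by the empirical distribution of $\{\bar\theta_{b,n}^*, b = 1, \ldots, B\}$. In particular, the covariance matrix of $\bar\theta_n$ can be estimated by the sample covariance matrix of $\{\bar\theta_{b,n}^*, b = 1, \ldots, B\}$. Estimating the distribution of $\sqrt{n}(\bar\theta_n - \theta_0)$ by that of $\sqrt{n}(\bar\theta_{b,n}^* - \bar\theta_n)$ leads to confidence regions for $\theta_0$. The resulting inferential procedure retains the numerical simplicity of SGD, using only one pass over the data. It scales well to datasets with millions of data points, and its theoretical validity can be justified under two general model settings with mild regularity conditions, as shown in the next section.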
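A minimal one-pass sketch of (3)-(8) follows, assuming NumPy. The $B$ perturbed iterates are stored as rows of a matrix next to the unperturbed iterate (row 0, weight 1), so each incoming observation triggers a single vectorized update. Several details are assumptions of this sketch rather than statements from the text: the weights are drawn from the standard exponential distribution (one choice with mean and variance one), all iterates start at $\hat\theta_0 = 0$, the gradient callback `grad(Theta, x, y)` is written for regression data and evaluated row-wise, the optional `burn_in` drops early iterates from the averages, and the names `sgd_with_perturbation` and `ls_grad` are hypothetical.

```python
import numpy as np


def sgd_with_perturbation(grad, X, y, B=200, gamma1=0.1, alpha=0.6,
                          burn_in=0, seed=0):
    """One pass over (X, y) updating the SGD estimate and B perturbed copies.

    grad(Theta, x, y) must return the per-observation gradient for every row of
    Theta (shape (k, p)) at the single data point (x, y).
    Perturbation weights are i.i.d. Exp(1), which have mean 1 and variance 1.
    The first burn_in iterates are left out of the averages.
    """
    rng = np.random.default_rng(seed)
    n_obs, p = X.shape
    Theta = np.zeros((B + 1, p))      # assumed start 0; row 0: plain SGD (3), rows 1..B: (7)
    Theta_bar = np.zeros((B + 1, p))  # running averages (4) and (8)
    m = 0                             # number of iterates averaged so far
    for n in range(1, n_obs + 1):
        gamma_n = gamma1 * n ** (-alpha)
        w = np.ones(B + 1)
        w[1:] = rng.exponential(1.0, size=B)  # fresh weights W_{b,n}; row 0 keeps weight 1
        Theta -= gamma_n * w[:, None] * grad(Theta, X[n - 1], y[n - 1])
        if n > burn_in:
            m += 1
            Theta_bar += (Theta - Theta_bar) / m
    return Theta_bar[0], Theta_bar[1:]


def ls_grad(Theta, x, y):
    """Gradient of (y - x'theta)^2 / 2 with respect to theta, for every row of Theta."""
    return -(y - Theta @ x)[:, None] * x


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    theta_true = np.array([1.0, -0.5, 0.25])
    X = rng.normal(size=(5000, 3))
    y = X @ theta_true + rng.normal(size=5000)
    theta_bar, theta_star_bar = sgd_with_perturbation(ls_grad, X, y, burn_in=1000)
    se = theta_star_bar.std(axis=0, ddof=1)  # spread of the B perturbed averages
    print("estimate:", theta_bar)
    print("standard error:", se)
    print("95% CI lower:", theta_bar - 1.96 * se)
    print("95% CI upper:", theta_bar + 1.96 * se)
```

The usage block forms a Wald interval $\bar\theta_n \pm 1.96\,\widehat{\mathrm{se}}$; the empirical quantiles of $\bar\theta_{b,n}^* - \bar\theta_n$ could be used instead of the normal approximation.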

3 Theoretical Results

In this section, we derive the theoretical properties of $\bar\theta_n^*$, justifying that the conditional distribution of $\sqrt{n}(\bar\theta_n^* - \bar\theta_n)$ given the data approximates the sampling distribution of $\sqrt{n}(\bar\theta_n - \theta_0)$ under the following two model settings.

3.1 Model Setting 1

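We first consider the setting where the objective function $L(\theta)$ in (1) is smooth. This includes linear regression, logistic regression and other generalized linear models as special cases. To ensure the consistency and asymptotic normality of the SGD estimator and the validity of the proposed resampling procedure, we impose the following assumptions.

  1. The objective function $L(\theta)$ is continuously differentiable and strongly convex with constant $c_1 > 0$; that is, for any $\theta_1$ and $\theta_2$, $L(\theta_2) \ge L(\theta_1) + \nabla L(\theta_1)^\top(\theta_2 - \theta_1) + \frac{c_1}{2}\|\theta_2 - \theta_1\|^2$.

  2. The gradient of $L(\theta)$, $\nabla L(\theta)$, is Lipschitz continuous with constant $c_2$; that is, for any $\theta_1$ and $\theta_2$, $\|\nabla L(\theta_1) - \nabla L(\theta_2)\| \le c_2\|\theta_1 - \theta_2\|$.

  3. Let $H(\theta)$ be the Hessian matrix of $L(\theta)$. Assume that $H(\theta)$ exists and is continuous in a neighborhood of $\theta_0$, and that $H = H(\theta_0)$ is positive definite.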

  4. Let . Let and assume for some . Assume as .

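Following arguments similar to those in Ruppert (1988) and Polyak & Juditsky (1992), the averaged SGD estimator $\bar\theta_n$ is asymptotically normal under Model Setting 1.

Lemma 1.

If Assumptions A1-A4 are satisfied, then

$$\sqrt{n}(\bar\theta_n - \theta_0) \xrightarrow{d} N(0, H^{-1} S H^{-1}), \tag{9}$$

where $S = E\{\nabla\ell(\theta_0; Z)\nabla\ell(\theta_0; Z)^\top\}$.

By Lemma 1, we can use the plug-in procedure to estimate the asymptotic covariance matrix $H^{-1} S H^{-1}$ of $\sqrt{n}(\bar\theta_n - \theta_0)$, where $H$ and $S$ can be conveniently estimated recursively using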

(10)
(11)
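Equations (10) and (11) are not legible in this version of the text. As an assumption, the sketch below uses one natural recursive form: running averages of the per-observation Hessian and of the outer product of the per-observation gradient, both evaluated at the current SGD iterate, combined into the sandwich $\hat H_n^{-1}\hat S_n\hat H_n^{-1}$. It assumes NumPy; the function name `plugin_se`, its arguments and the `burn_in` option are hypothetical.

```python
import numpy as np


def plugin_se(grad, hess, X, y, gamma1=0.1, alpha=0.6, burn_in=0):
    """Averaged SGD plus a plug-in (sandwich) standard error.

    Assumed recursive estimates: H_hat and S_hat are running means of the
    per-observation Hessian and of grad * grad^T at the current iterate.
    """
    n_obs, p = X.shape
    theta = np.zeros(p)
    theta_bar = np.zeros(p)
    H_hat = np.zeros((p, p))
    S_hat = np.zeros((p, p))
    m = 0
    for n in range(1, n_obs + 1):
        x, yn = X[n - 1], y[n - 1]
        g = grad(theta, x, yn)            # gradient at theta_{n-1}
        h = hess(theta, x, yn)            # Hessian at theta_{n-1}
        theta = theta - gamma1 * n ** (-alpha) * g
        if n > burn_in:
            m += 1
            theta_bar += (theta - theta_bar) / m
            H_hat += (h - H_hat) / m
            S_hat += (np.outer(g, g) - S_hat) / m
    H_inv = np.linalg.inv(H_hat)
    cov = H_inv @ S_hat @ H_inv / m       # estimated covariance of theta_bar
    return theta_bar, np.sqrt(np.diag(cov))


if __name__ == "__main__":
    rng = np.random.default_rng(2)
    X = rng.normal(size=(5000, 3))
    y = X @ np.array([1.0, -0.5, 0.25]) + rng.normal(size=5000)
    ls_grad = lambda th, x, yy: -(yy - x @ th) * x   # least-squares gradient
    ls_hess = lambda th, x, yy: np.outer(x, x)       # least-squares Hessian
    print(plugin_se(ls_grad, ls_hess, X, y, burn_in=1000))
```

This needs a closed-form Hessian and a $p \times p$ inversion, which is exactly what the resampling procedure avoids.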

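We illustrate this setting and the regularity conditions with two examples. The data consist of $(Y_i, X_i)$, $i = 1, \ldots, n$, which are i.i.d. copies of $(Y, X)$, where $Y$ denotes the response variable and $X$ is the $p$-dimensional vector of covariates. Assume that .

Example 1 (Linear regression). Suppose that $(Y_i, X_i)$, $i = 1, \ldots, n$, are from the linear regression model,

$$Y_i = X_i^\top\theta_0 + \varepsilon_i. \tag{12}$$

Assume that the $\varepsilon_i$ are i.i.d. with $E(\varepsilon_i) = 0$ and $\mathrm{Var}(\varepsilon_i) = \sigma^2$, that $\varepsilon_i$ and $X_i$ are mutually independent, and that $E(XX^\top)$ is positive definite. Let $Z = (Y, X)$ and $\ell(\theta; Z) = (Y - X^\top\theta)^2/2$, so that $H = E(XX^\top)$ and $S = \sigma^2 E(XX^\top)$. It can be easily verified that Assumptions A1-A4 hold and that the SGD and perturbed SGD updates for $\theta_0$, as defined in (3) and (5) respectively, are

$$\hat\theta_n = \hat\theta_{n-1} + \gamma_n X_n(Y_n - X_n^\top\hat\theta_{n-1}), \tag{13}$$
$$\hat\theta_n^* = \hat\theta_{n-1}^* + \gamma_n W_n X_n(Y_n - X_n^\top\hat\theta_{n-1}^*). \tag{14}$$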
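As a sketch of (13) and (14) only (NumPy assumed; the function name `ls_sgd_step` is hypothetical), one least-squares step can be applied at once to several iterates stored as rows of a matrix, each with its own weight. A weight of 1 reproduces (13) and a random weight gives (14).

```python
import numpy as np


def ls_sgd_step(Theta, x, y, gamma_n, w):
    """One least-squares SGD step for every row of Theta.

    Row b is updated as theta_b + gamma_n * w[b] * x * (y - x'theta_b);
    w[b] = 1 gives update (13), a random w[b] gives the perturbed update (14).
    """
    residual = y - Theta @ x                     # one residual per row, shape (k,)
    return Theta + gamma_n * (w * residual)[:, None] * x


if __name__ == "__main__":
    Theta = np.zeros((2, 2))                     # two copies, both starting at 0
    x, y = np.array([1.0, 2.0]), 3.0
    print(ls_sgd_step(Theta, x, y, 0.5, np.array([1.0, 2.0])))
```

With both copies at zero, the residual is 3 for each, so the copy with weight 2 moves twice as far as the unperturbed one; a copy that already fits the observation exactly is left unchanged.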

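Example 2 (Logistic regression). Suppose that $(Y_i, X_i)$, $i = 1, \ldots, n$, with $Y_i \in \{0, 1\}$, are from the logistic regression model,

$$P(Y_i = 1 \mid X_i) = \frac{\exp(X_i^\top\theta_0)}{1 + \exp(X_i^\top\theta_0)}. \tag{15}$$

Let $Z = (Y, X)$ and let $\ell(\theta; Z) = -YX^\top\theta + \log\{1 + \exp(X^\top\theta)\}$ be the negative log-likelihood of one observation. It can be verified that Assumptions A1-A4 hold. The SGD and perturbed SGD updates for $\theta_0$, as defined in (3) and (5) respectively, are

$$\hat\theta_n = \hat\theta_{n-1} + \gamma_n X_n\left\{Y_n - \frac{\exp(X_n^\top\hat\theta_{n-1})}{1 + \exp(X_n^\top\hat\theta_{n-1})}\right\}, \tag{16}$$
$$\hat\theta_n^* = \hat\theta_{n-1}^* + \gamma_n W_n X_n\left\{Y_n - \frac{\exp(X_n^\top\hat\theta_{n-1}^*)}{1 + \exp(X_n^\top\hat\theta_{n-1}^*)}\right\}. \tag{17}$$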
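The logistic updates (16) and (17) differ from the least-squares updates only in the residual, which becomes $Y_n$ minus the fitted probability. The sketch below assumes NumPy and uses hypothetical names; it evaluates the logistic function as $\{1 + \tanh(t/2)\}/2$, an exact identity that avoids overflow in `exp` for large $|t|$.

```python
import numpy as np


def expit(t):
    """Logistic function exp(t) / (1 + exp(t)), via the identity (1 + tanh(t/2)) / 2."""
    return 0.5 * (1.0 + np.tanh(0.5 * np.asarray(t, dtype=float)))


def logit_sgd_step(Theta, x, y, gamma_n, w):
    """One logistic-regression SGD step for every row of Theta, with y in {0, 1}.

    w[b] = 1 gives update (16); a random weight w[b] gives the perturbed update (17).
    """
    residual = y - expit(Theta @ x)              # y minus fitted probability, shape (k,)
    return Theta + gamma_n * (w * residual)[:, None] * x


if __name__ == "__main__":
    # Averaged SGD (16) on simulated data with an assumed true parameter (0.5, -0.5).
    rng = np.random.default_rng(3)
    theta_true = np.array([0.5, -0.5])
    Theta, Theta_bar = np.zeros((1, 2)), np.zeros((1, 2))
    for n in range(1, 20001):
        x = rng.normal(size=2)
        y = float(rng.random() < expit(x @ theta_true))
        Theta = logit_sgd_step(Theta, x, y, n ** -0.6, np.ones(1))
        Theta_bar += (Theta - Theta_bar) / n
    print(Theta_bar[0])
```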

3.2 Model Setting 2

Model Setting 1 covers smooth objective functions in general, not necessarily restricted to the regression case. Next we consider a general regression setting that allows for non-smooth loss functions, including quantile regression as a special case. Suppose that the data $(Y_i, X_i)$, $i = 1, \ldots, n$, are from model (12), and the loss function is

$$\ell(\theta; Z) = \rho(Y - X^\top\theta), \tag{18}$$

where $\rho$ is a convex function with . We require the following regularity conditions.

  1. Assume that are i.i.d. copies of , and are mutually independent, and . Let .

  2. Assume that is a convex function on with the right derivative being and left derivative being . Let be a function such that . There exists constant such that .

  3. Let . Assume that , for any , and has a derivative at with . There exist constants and such that for .

  4. Let . Assume that is finite for in a neighborhood of and is continuous at .

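By Assumption B2, with $\psi$ denoting the function given there that lies between the left and right derivatives of $\rho$, the SGD and perturbed SGD updates for $\theta_0$, as defined in (3) and (5) respectively, are

$$\hat\theta_n = \hat\theta_{n-1} + \gamma_n \psi(Y_n - X_n^\top\hat\theta_{n-1}) X_n, \tag{19}$$
$$\hat\theta_n^* = \hat\theta_{n-1}^* + \gamma_n W_n \psi(Y_n - X_n^\top\hat\theta_{n-1}^*) X_n. \tag{20}$$

We establish the asymptotic normality of the SGD estimator under Model Setting 2 as follows.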

Lemma 2.

If Assumptions B1-B4 are satisfied, then we have

(21)

We illustrate the model setting 2 with two examples.

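Example 1 (Linear regression). We revisit Example 1 with $\rho(u) = u^2/2$, so that $\psi(u) = u$. In this case the centering condition of Assumption B3 amounts to $E(\varepsilon) = 0$, and $E\{\psi^2(\varepsilon)\} = \sigma^2$. Therefore, the asymptotic covariance matrix in (21) is $\sigma^2\{E(XX^\top)\}^{-1}$, the same as that given by Lemma 1.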

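Example 3 (Quantile regression). Consider the check loss $\rho(u) = u\{\tau - I(u < 0)\}$, where $\tau \in (0, 1)$, so that $\psi(u) = \tau - I(u < 0)$. Then the centering condition of Assumption B3 is equivalent to the $\tau$-quantile of $\varepsilon$ being 0, the derivative appearing in Assumption B3 equals $f(0)$, where $f$ is the density of $\varepsilon$, and $E\{\psi^2(\varepsilon)\} = \tau(1 - \tau)$. The SGD and perturbed SGD updates for $\theta_0$, as defined in (3) and (5) respectively, are

$$\hat\theta_n = \hat\theta_{n-1} + \gamma_n\{\tau - I(Y_n - X_n^\top\hat\theta_{n-1} < 0)\}X_n, \tag{22}$$
$$\hat\theta_n^* = \hat\theta_{n-1}^* + \gamma_n W_n\{\tau - I(Y_n - X_n^\top\hat\theta_{n-1}^* < 0)\}X_n, \tag{23}$$

and the asymptotic covariance matrix in (21) is $\tau(1 - \tau) f(0)^{-2}\{E(XX^\top)\}^{-1}$. As this covariance matrix involves the unknown density function $f$, the plug-in procedure is not directly applicable in this example.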
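A sketch of the quantile updates (22) and (23), assuming NumPy and using hypothetical names. Because $\rho$ is not differentiable at zero, the code uses $\psi(u) = \tau - I(u < 0)$, which takes the value $\tau$ at $u = 0$ (the right derivative). The update needs neither the density $f$ nor any matrix inversion, which is why the resampling approach still applies where the plug-in covariance does not.

```python
import numpy as np


def quantile_sgd_step(Theta, x, y, gamma_n, w, tau=0.5):
    """One quantile-regression SGD step for every row of Theta.

    psi(u) = tau - 1{u < 0} is a subgradient selection for the check loss
    rho(u) = u * (tau - 1{u < 0}); w[b] = 1 gives (22), a random w[b] gives (23).
    """
    u = y - Theta @ x                            # residual for each row, shape (k,)
    psi = tau - (u < 0).astype(float)
    return Theta + gamma_n * (w * psi)[:, None] * x


if __name__ == "__main__":
    # LAD (tau = 0.5) with Laplace errors and an assumed true parameter (1, -1).
    rng = np.random.default_rng(4)
    theta_true = np.array([1.0, -1.0])
    Theta, Theta_bar = np.zeros((1, 2)), np.zeros((1, 2))
    for n in range(1, 20001):
        x = rng.normal(size=2)
        y = x @ theta_true + rng.laplace()
        Theta = quantile_sgd_step(Theta, x, y, n ** -0.6, np.ones(1))
        Theta_bar += (Theta - Theta_bar) / n
    print(Theta_bar[0])
```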

3.3 Asymptotic properties

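Let $P^*$ and $E^*$ denote the conditional probability and expectation given the data $\{Z_1, \ldots, Z_n\}$, respectively. Recall that the perturbation variables satisfy $E(W_n) = \mathrm{Var}(W_n) = 1$ and that the learning rate is $\gamma_n = \gamma_1 n^{-\alpha}$ with $\gamma_1 > 0$ and $\alpha \in (0.5, 1)$. We derive the following two theorems for Model Settings 1 and 2, respectively.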

Theorem 1.

(Model Setting 1) If Assumptions A1-A4 hold, then we have (i),

(24)

and (ii),

(25)
Theorem 2.

(Model Setting 2) If Assumptions B1-B4 hold, then we have (i),

(26)

and (ii),

(27)

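By Theorems 1 and 2, under either Model Setting 1 or Model Setting 2, the Kolmogorov-Smirnov distance between the conditional distribution of $\sqrt{n}(\bar\theta_n^* - \bar\theta_n)$ given the data and the distribution of $\sqrt{n}(\bar\theta_n - \theta_0)$ converges to zero in probability. This validates the proposed perturbation-based resampling procedure for inference with SGD.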
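The Kolmogorov-Smirnov distance in Theorems 1 and 2 is the supremum of the absolute difference between two distribution functions. As an illustrative diagnostic only (not part of the proposed procedure), for a single coordinate one can compute the distance between the empirical distribution of $B$ draws and a reference distribution function, for instance a normal limit that is known in a simulation. The sketch assumes NumPy and the standard-library `math` module; `ks_distance` and `std_normal_cdf` are hypothetical names.

```python
import numpy as np
from math import erf, sqrt


def ks_distance(sample, cdf):
    """Kolmogorov-Smirnov distance sup_t |F_B(t) - cdf(t)|, with F_B the empirical CDF."""
    x = np.sort(np.asarray(sample, dtype=float))
    B = x.size
    F = np.array([cdf(t) for t in x])
    above = np.arange(1, B + 1) / B - F   # empirical CDF at each point minus cdf
    below = F - np.arange(0, B) / B       # cdf minus empirical CDF just before each point
    return float(max(above.max(), below.max()))


def std_normal_cdf(t):
    """Standard normal distribution function via the error function."""
    return 0.5 * (1.0 + erf(t / sqrt(2.0)))


if __name__ == "__main__":
    rng = np.random.default_rng(5)
    print(ks_distance(rng.normal(size=20000), std_normal_cdf))            # small
    print(ks_distance(rng.normal(loc=0.5, size=20000), std_normal_cdf))   # clearly larger
```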

4 Numerical results

4.1 Simulation studies

To assess the performance of the proposed perturbation-based resampling (a.k.a. random weighting; RW) procedure for SGD estimators, we conduct simulation studies for the three examples discussed in Section 3. We compare the proposed procedure with the plug-in procedure, where applicable, as described in (10) and (11). We do not compare with the batch-means procedure proposed by Chen et al. (2016), because their program is not publicly available and depends on several tuning parameters (personal communication).

Example 1 (Least-squares regression): Consider model (12), where covariates and error are independently generated from the standard normal distribution $N(0, 1)$. Here indicates the -th dimension of . Let (same for the other two examples). Consider least-squares (LS) regression; the corresponding SGD estimators are the ones defined in (13) and (14).

Example 2 (Logistic regression): Consider logistic (Logit) regression (15), where covariates are independently generated from and response is generated from Bernoulli distribution. The corresponding SGD estimators are the ones defined in (16) and (17).

Example 3 (Least-absolute-deviation regression): Consider model (12), where covariates and error are independently generated from and respectively. Consider quantile regression with $\tau = 0.5$, which is equivalent to least-absolute-deviation (LAD) regression. The corresponding SGD estimators are the ones defined in (22) and (23) with $\tau = 0.5$.

Setting Method Dim 1 Dim 4 Dim 7
(10000,10,6,0.1) RW 0.962 0.946 0.948
Plug in 0.901 0.917 0.900
(10000,10,6,0.2) RW 0.940 0.948 0.953
Plug in 0.898 0.924 0.902
(10000,10,6,0.3) RW 0.937 0.945 0.943
Plug in 0.908 0.904 0.906
(20000,20,6,0.1) RW 0.952 0.966 0.969
Plug in 0.893 0.882 0.902
(20000,20,6,0.2) RW 0.957 0.962 0.969
Plug in 0.918 0.902 0.927
(20000,20,6,0.3) RW 0.965 0.954 0.961
Plug in 0.913 0.918 0.926
Table 1: Coverage probabilities of 95% confidence intervals for LS regression.
Setting Method Dim 1 Dim 4 Dim 7
(10000,10,6,0.1) RW 0.0157 0.0158 0.0158
Plug in 0.0137 0.0137 0.0137
Empirical 0.0156 0.0157 0.0158
(10000,10,6,0.2) RW 0.0158 0.0158 0.0158
Plug in 0.0137 0.0137 0.0137
Empirical 0.0164 0.0157 0.0154
(10000,10,6,0.3) RW 0.0158 0.0158 0.0158
Plug in 0.0137 0.0137 0.0137
Empirical 0.0164 0.0162 0.0163
(20000,20,6,0.1) RW 0.0114 0.0114 0.0114
Plug in 0.0096 0.0096 0.0096
Empirical 0.0114 0.0104 0.0104
(20000,20,6,0.2) RW 0.0114 0.0114 0.0115
Plug in 0.0096 0.0096 0.0096
Empirical 0.0109 0.0108 0.0105
(20000,20,6,0.3) RW 0.0115 0.0115 0.0114
Plug in 0.0096 0.0096 0.0096
Empirical 0.0108 0.0108 0.0107
Table 2: Averaged estimated SE and empirical SE for LS regression.
Setting Method Dim 1 Dim 4 Dim 7
(10000,10,6,0.1) RW 0.955 0.955 0.961
Plug in 0.900 0.910 0.877
(10000,10,6,0.2) RW 0.969 0.951 0.956
Plug in 0.887 0.878 0.878
(10000,10,6,0.3) RW 0.966 0.970 0.954
Plug in 0.881 0.886 0.895
(20000,20,6,0.1) RW 0.947 0.960 0.943
Plug in 0.878 0.891 0.890
(20000,20,6,0.2) RW 0.957 0.952 0.931
Plug in 0.891 0.875 0.885
(20000,20,6,0.3) RW 0.963 0.959 0.938
Plug in 0.861 0.853 0.861
Table 3: Coverage probabilities of 95% confidence intervals for Logit regression.
Setting Method Dim 1 Dim 4 Dim 7
(10000,10,6,0.1) RW 0.0234 0.0233 0.0232
Plug in 0.0119 0.0119 0.0120
Empirical 0.0228 0.0227 0.0231
(10000,10,6,0.2) RW 0.0246 0.0246 0.0240
Plug in 0.0111 0.0111 0.0111
Empirical 0.0226 0.0241 0.0229
(10000,10,6,0.3) RW 0.0268 0.0268 0.0254
Plug in 0.0100 0.0100 0.0103
Empirical 0.0250 0.0245 0.0251
(20000,20,6,0.1) RW 0.0158 0.0157 0.0157
Plug in 0.0084 0.0084 0.0085
Empirical 0.0160 0.0153 0.0156
(20000,20,6,0.2) RW 0.0165 0.0165 0.0161
Plug in 0.0078 0.0078 0.0079
Empirical 0.0161 0.0164 0.0168
(20000,20,6,0.3) RW 0.0182 0.0181 0.0169
Plug in 0.0069 0.0069 0.0073
Empirical 0.0166 0.0173 0.0170
Table 4: Averaged estimated SE and empirical SE for Logit regression.
Setting Method Dim 1 Dim 4 Dim 7
(10000,10,6,0.1) RW 0.968 0.956 0.960
Plug in NA NA NA
(10000,10,6,0.2) RW 0.958 0.953 0.966
Plug in NA NA NA
(10000,10,6,0.3) RW 0.956 0.963 0.959
Plug in NA NA NA
(20000,20,6,0.1) RW 0.971 0.962 0.969
Plug in NA NA NA
(20000,20,6,0.2) RW 0.959 0.969 0.966
Plug in NA NA NA
(20000,20,6,0.3) RW 0.953 0.959 0.960
Plug in NA NA NA
Table 5: Coverage probabilities for 95% confidence intervals for LAD regression.
Setting Method Dim 1 Dim 4 Dim 7
(10000,10,6,0.1) RW 0.0130 0.0129 0.0129
Plug in NA NA NA
Empirical 0.0119 0.0117 0.0120
(10000,10,6,0.2) RW 0.0130 0.0129 0.0129
Plug in NA NA NA
Empirical 0.0121 0.0120 0.0120
(10000,10,6,0.3) RW 0.0129 0.0130 0.0130
Plug in NA NA NA
Empirical 0.0129 0.0117 0.0122
(20000,20,6,0.1) RW 0.0091 0.0090 0.0091
Plug in NA NA NA
Empirical 0.0081 0.0085 0.0081
(20000,20,6,0.2) RW 0.0090 0.0090 0.0090
Plug in NA NA NA
Empirical 0.0086 0.0081 0.0083
(20000,20,6,0.3) RW 0.0091 0.0090 0.0090
Plug in NA NA NA
Empirical 0.0083 0.0083 0.0084
Table 6: Averaged estimated SE and empirical SE for LAD regression.

For each example, we consider six scenarios, each described by (sample size, number of covariates, number of useful covariates, effect size), where the sample size is 10000 or 20000, the number of covariates is 10 or 20 (10 with sample size 10000 and 20 with sample size 20000), the number of useful covariates is 6, and the effect size is 0.1, 0.2 or 0.3. For each example, we repeat the data generation 1000 times. For each data repetition, we use as random weights and generate copies of random weights whenever a new data point is read. Then, for each data repetition, we obtain the SGD estimator (4), apply the proposed perturbation-based resampling procedure to estimate its standard error, and apply the plug-in procedure (if applicable) to estimate its standard error as well. When calculating the averaged SGD estimates (4) and (6), we exclude the first 2000 iterates. Based on the estimated standard error $\widehat{\mathrm{se}}$, we construct a 95% confidence interval of the form $\bar\theta_n \pm 1.96\,\widehat{\mathrm{se}}$ for each coordinate and check whether it covers the true parameter value. We also obtain the empirical standard error from the 1000 repeated SGD estimates, which is regarded as a good approximation to the true standard error.
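The coverage and standard-error summaries reported in Tables 1-6 can be computed from the replicated estimates with a few array operations. The sketch below is an assumption about how such summaries are formed, not the authors' code; it assumes NumPy, and the function name `summarize_replications` and the $(R, p)$ array layout are hypothetical.

```python
import numpy as np


def summarize_replications(estimates, ses, truth, z=1.96):
    """Coverage and SE summaries over R simulated data sets.

    estimates, ses: arrays of shape (R, p) holding, for each replication, the
    averaged SGD estimate and its estimated standard error.
    Returns, per coordinate: coverage of estimate +/- z * se, the average
    estimated SE, and the empirical SE (standard deviation of the R estimates).
    """
    estimates = np.asarray(estimates, dtype=float)
    ses = np.asarray(ses, dtype=float)
    covered = np.abs(estimates - np.asarray(truth, dtype=float)) <= z * ses
    return covered.mean(axis=0), ses.mean(axis=0), estimates.std(axis=0, ddof=1)


if __name__ == "__main__":
    rng = np.random.default_rng(6)
    truth = np.array([0.0, 1.0])
    est = truth + 0.1 * rng.normal(size=(20000, 2))   # estimates with true SE 0.1
    print(summarize_replications(est, np.full((20000, 2), 0.1), truth))
    print(summarize_replications(est, np.full((20000, 2), 0.08), truth))
```

If the standard errors are understated by 20% (0.08 instead of 0.1 for normally distributed estimates), the nominal 95% interval covers only about 88% of the time, the kind of under-coverage seen for the plug-in procedure.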

The coverage probabilities of the 95% confidence intervals constructed using our procedure (RW) and the plug-in procedure (Plug-in) are summarized in Tables 1, 3 and 5 for Examples 1-3, respectively. We only report results for the first, fourth and seventh covariates, and the plug-in procedure is not applicable for Example 3. From these tables, we see that the coverage probabilities of the RW procedure are close to 95%, while those of the plug-in procedure are substantially smaller than 95%. Similar findings for the plug-in procedure were also reported in Chen et al. (2016). Therefore, in these simulations our procedure outperforms the plug-in procedure.

We also compare the average estimated standard errors (SE) using the RW and plug-in procedures with those empirical standard errors, which are thought to be close to the true standard error. The results are summarized in Tables 2, 4 and 6 for Examples 1-3 respectively. Again, we only report results corresponding to those three covariates and the plug-in procedure is not applicable for Example 3. From these tables, we see that the average estimated standard errors using the RW procedure are close to those empirical standard errors, while the average estimated standard errors from the plug-in procedure are substantially smaller.

4.2 Real data applications

In this section, we apply the proposed method to conduct linear regression analysis of the individual household electric power consumption dataset (POWER) and logistic regression analysis of the gas sensors for home activity monitoring dataset (GAS). Both datasets are publicly available from the UCI Machine Learning Repository.

The POWER data contain 2,075,259 observations, and we fit a linear regression model to investigate the relationship between the time of day and the response variable “sub-metering-1”, the energy sub-metering No. 1 (in watt-hours of active energy), which corresponds to the kitchen, containing mainly a dishwasher, an oven and a microwave. Observations with missing values are deleted, and the time of day is divided into 8 categories: “0-2”, “3-5”, “6-8”, “9-11”, “12-14”, “15-17”, “18-20” and “21-23”. The GAS data contain 919,438 observations, and we only use a subset of 652,024 observations whose response value is either “banana” or “wine”. We fit a logistic regression model to examine the association between the response variable and 11 covariates: time, R1 to R8, temperature and humidity.

Although standard software such as SAS and R can fit linear and logistic regression to such datasets without difficulty, for illustration we use SGD as in Examples 1 and 2 to fit the linear and logistic regressions and use the proposed perturbation-based resampling procedure to construct confidence intervals. The point estimates and 95% confidence intervals of the coefficients are shown in Tables 7 and 8 for the POWER data and the GAS data, respectively. From Table 7, we see that the electric power consumption in the kitchen is relatively high in the evening and at night. From Table 8, we see that all the variables except R4 are statistically significantly associated with the response. Further, we display histograms of the perturbation-based SGD estimates for each coefficient in Figures 1 and 2 for the POWER data and the GAS data, respectively. The vertical line in each panel indicates the SGD estimate of the corresponding coefficient. From these figures, we see that the perturbation-based procedure can be used to estimate the whole sampling distribution, not only the standard error, of each SGD estimator.

Variable Point estimate 95% CI
Time 0-2
Time 3-5
Time 6-8
Time 9-11
Time 12-14
Time 15-17
Time 18-20
Time 21-23
Table 7: Point estimates and 95% confidence intervals of the coefficients for the POWER data.
Variable Point estimate 95% CI
Time
R1
R2
R3
R4
R5
R6
R7
R8
Temperature
Humidity
Table 8: Point estimates and 95% confidence intervals of the coefficients for the GAS data.
Figure 1: Histograms of perturbation-based SGD estimates for the POWER data.
Figure 2: Histograms of perturbation-based SGD estimates for the GAS data.

5 Discussion

Online updating is a useful strategy for analyzing big data and streaming data, and stochastic gradient descent has recently become a popular method for online updating. Although the asymptotic properties of SGD have been well studied, there is little research on statistical inference based on SGD estimators. In this paper, we propose a perturbation-based resampling procedure that can be applied to estimate the sampling distribution of an SGD estimator. The offline version of the perturbation-based resampling procedure was first proposed by Rubin et al. (1981) and was also discussed in Shao & Tu (2012).

The proposed resampling procedure is in essence an online version of the bootstrap. Recall that the data points $Z_1, \ldots, Z_n$ arrive one at a time and that an SGD estimator updates itself from $\hat\theta_{i-1}$ to $\hat\theta_i$ whenever a new data point $Z_i$ arrives. If we were to apply the bootstrap, we would need many bootstrap samples; the data points of each bootstrap sample, $Z_1^*, \ldots, Z_n^*$, would be assumed to arrive one at a time, and the SGD estimator would update itself from $\hat\theta_{i-1}^*$ to $\hat\theta_i^*$ whenever a new data point arrives. Of course, the bootstrap is impractical here, because in online updating we cannot collect all the data points first and then generate bootstrap samples. Now suppose we rearrange a hypothetical bootstrap sample so that it consists of $W_1$ copies of $Z_1$, followed by $W_2$ copies of $Z_2$, and so on, where each $W_i$ follows the binomial distribution $\mathrm{Bin}(n, 1/n)$. Then the SGD estimator updates itself from $\hat\theta_{i-1}^*$ to $\hat\theta_i^*$ whenever a new batch of data points, $W_i$ copies of $Z_i$, arrives. Since $\mathrm{Bin}(n, 1/n)$ converges to the Poisson distribution $\mathrm{Poisson}(1)$ as $n \to \infty$, the hypothetical bootstrap is approximately equivalent to the proposed perturbation-based resampling procedure with $W_i \sim \mathrm{Poisson}(1)$, whose mean and variance are both equal to one.
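The Poisson limit invoked in this argument can be checked numerically: the largest pointwise difference between the probability mass functions of $\mathrm{Bin}(n, 1/n)$ and $\mathrm{Poisson}(1)$ shrinks as $n$ grows. A standard-library Python sketch follows; the function names are hypothetical.

```python
from math import comb, exp, factorial


def binom_pmf(k, n, p):
    """P(Bin(n, p) = k)."""
    return comb(n, k) * p ** k * (1 - p) ** (n - k)


def poisson_pmf(k, lam=1.0):
    """P(Poisson(lam) = k)."""
    return exp(-lam) * lam ** k / factorial(k)


def max_pmf_gap(n, kmax=10):
    """Largest |P(Bin(n, 1/n) = k) - P(Poisson(1) = k)| over k = 0..kmax."""
    return max(abs(binom_pmf(k, n, 1.0 / n) - poisson_pmf(k)) for k in range(kmax + 1))


if __name__ == "__main__":
    for n in (10, 100, 1000, 10000):
        print(n, max_pmf_gap(n))
```

For $n = 10$ the gap at $k = 0$ is about 0.02 ($0.9^{10}$ versus $e^{-1}$), and by $n = 1000$ it has fallen below $10^{-3}$.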

Finally, the SGD method considered in this paper is the explicit SGD, in contrast with the implicit SGD considered in Toulis & Airoldi (2014). We are working on extending the perturbation-based resampling procedure proposed in this paper to statistical inference for the implicit SGD.

Appendix

For ease of exposition in establishing the asymptotic normality of SGD and perturbed SGD estimates, we present the following Proposition 1, adapted from Theorem 2 of Polyak & Juditsky (1992, page 841). Let be some unknown function and . The data consist of which are i.i.d. copies of . Stochastic gradients are and . With an initial point and the learning rate , the SGD estimate is defined as

(A.1)

where , and . The regularity conditions for Proposition 1 are listed as follows.

  1. There exists a function such that for some , , , , and all , the conditions , , , for hold true. Moreover, for all .

  2. There exists a positive definite matrix such that for some , , and , the condition for all holds true.

  3. is a martingale difference process, that is, almost surely, and for some ,

    for all . Consider decomposition , where and . Assume that a.s.,

    and there exists as such that, for all large enough,

Proposition 1. If Assumptions C1-C3 are satisfied, then (i): ,  a.s.;
and (ii):

(A.2)

and

Proof of Lemma 1:
By Proposition 1, it is sufficient to show that Assumptions C1-C3 hold under Assumptions A1-A4. Let , , and . C1 easily follows from Assumptions A1-A3. Note that , Assumption A1 implies that , Assumption A2 implies that for , Assumption A1 implies that for , and Assumption A3 implies that over some neighborhood of . Next, we can see that Assumption A3 implies that Assumption C2 holds for , and therefore. Finally, noting that , where and , we can see that Assumptions A1 and A4 imply that the conditions in Assumption C3 about and are satisfied.
Proof of Lemma 2:
Define , and . Let and . We verify that Assumptions C1-C3 hold if Assumptions B1-B4 are satisfied. First, let . Assumption B2 implies that the conditions about in Assumption C1 are satisfied. Second, Assumption B3 implies that for some over a neighborhood of , , and . Third, to verify Assumption C3, we consider the decomposition , where and . By Assumptions B1 and B4, we can see that and the conditions about in Assumption C3 are satisfied, and by Assumptions B2 and B4, we can see that the condition about in Assumption C3 is satisfied. Lemma 2 then follows from Proposition 1 and

(A.3)

Proof of Theorem 1:
(i). Rewrite as

(A.4)

where . Let denote the Borel field generated by . Since and , we have . Thus is a martingale-difference process. Let . Then , where

(A.5)

and . Since , we have and