On Scalable Inference with Stochastic Gradient Descent
Abstract
In many applications involving large datasets or online updating, stochastic gradient descent (SGD) provides a scalable way to compute parameter estimates and has gained increasing popularity due to its numerical convenience and memory efficiency. While the asymptotic properties of SGD-based estimators were established decades ago, statistical inference such as interval estimation remains largely unexplored. Traditional resampling methods such as the bootstrap are not computationally feasible since they require repeatedly drawing independent samples from the entire dataset. The plug-in method is not applicable when there is no explicit formula for the covariance matrix of the estimator. In this paper, we propose a scalable inferential procedure for stochastic gradient descent, which, upon the arrival of each observation, updates the SGD estimate as well as a large number of randomly perturbed SGD estimates. The proposed method is easy to implement in practice. We establish its theoretical properties for a general class of models that includes generalized linear models and quantile regression models as special cases. The finite-sample performance and numerical utility are evaluated by simulation studies and two real data applications.
Keywords: Bootstrap, Interval estimation, Generalized linear models, Large datasets, M-estimators, Quantile regression, Resampling methods, Stochastic gradient descent
1 Introduction
Big datasets arise frequently in clinical, epidemiological, financial and sociological studies. In such applications, classical optimization methods for parameter estimation such as Fisher scoring, the EM algorithm or iteratively reweighted least squares (Hastie et al. 2009, Nelder & Baker 1972) do not scale well and are computationally less attractive. Due to its computational and memory efficiency, stochastic gradient descent (SGD; Robbins & Monro 1951) provides a scalable way for parameter estimation and has recently drawn a great deal of attention. Unlike classical methods that evaluate the objective function over the entire dataset and require expensive matrix inversions, the SGD method calculates the gradient of the objective function using only one data point at a time and recursively updates the parameter estimate. This is also numerically appealing and particularly useful in online-updating settings such as streaming data, where it may not even be feasible to store the entire dataset at once. Wang et al. (2015) give a nice review of recent achievements in applying the SGD method to big data and streaming data.
The asymptotic properties of SGD estimators, such as consistency and asymptotic normality, were established a long time ago; see, for example, Ruppert (1988) and Polyak & Juditsky (1992). However, statistical inference such as confidence interval estimation for SGD estimators has remained largely unexplored. Traditional interval estimation procedures such as the plug-in procedure and the bootstrap are often numerically difficult in the presence of big datasets. The bootstrap repeatedly draws samples from the entire dataset and is thus computationally prohibitive. The plug-in estimator requires an explicit variance-covariance formula and involves expensive matrix inversion. In addition, the bootstrap is not applicable to the online setting, where each sample arrives sequentially and it may not be necessary or feasible to store the entire dataset. Neither of them provides a scalable way for interval estimation.
As far as we know, Chen et al. (2016) is the only work that considers statistical inference for the SGD method. Although computationally efficient, their proposed batch-means procedure substantially underestimates the variance of the SGD estimator in finite-sample studies, as shown in the simulation studies of Chen et al. (2016), because of the correlations between the batch means. In addition, the determination of the batch sizes is difficult.
In this paper, we propose a perturbation-based resampling procedure to approximate the distribution of an SGD estimator in a general class of models that includes generalized linear models and quantile regression as special cases. Our proposal, justified by asymptotic theory, provides a simple way to estimate the covariance matrix and confidence regions. Through numerical experiments, we verify the ability of this procedure to give accurate inference for big datasets.
The rest of the article is organized as follows. In Section 2, we introduce the proposed perturbation-based resampling procedure for constructing confidence regions. In Section 3, we theoretically justify the validity of our proposal for a general class of models. In Section 4, we demonstrate the performance of the proposed procedures in finite samples via simulation studies and two real data applications. Some concluding remarks are given in Section 5, and all the technical proofs are relegated to the Appendix.
2 The proposed resampling procedure
Parameter estimation by optimizing an objective function is often encountered in statistical practice. Consider the general situation where the optimal model parameter $\beta^*$ is defined to be the minimizer of the expected loss function,
$$\beta^* = \arg\min_{\beta} F(\beta), \qquad F(\beta) = E\{f(\beta; Z)\}, \tag{1}$$
where $f(\cdot;\cdot)$ is some loss function and $Z$ denotes a single observation. Suppose that the data consist of $N$ independent and identically distributed (i.i.d.) copies of $Z$, denoted by $Z_1, \ldots, Z_N$. Under mild conditions, $\beta^*$ can be consistently estimated by
$$\hat{\beta} = \arg\min_{\beta} \frac{1}{N} \sum_{i=1}^{N} f(\beta; Z_i). \tag{2}$$
However, the minimization problem (2) for big datasets with millions of data points poses numerical challenges for classical methods such as the Newton-Raphson algorithm and iteratively reweighted least squares. Furthermore, in applications such as online data where each sample arrives sequentially (e.g., search queries or transactional data), it may not be necessary or feasible to store the entire dataset, let alone evaluate the objective in (2).
As a stochastic approximation method (Robbins & Monro 1951), stochastic gradient descent provides a scalable way for parameter estimation with large-scale data. Given an initial estimate $\beta_0$, the SGD method recursively updates the estimate upon the arrival of each data point $Z_n$,
$$\beta_n = \beta_{n-1} - \gamma_n \nabla f(\beta_{n-1}; Z_n), \tag{3}$$
where $\nabla f(\beta; z) = \partial f(\beta; z)/\partial \beta$, and the learning rate $\gamma_n = \gamma_0 n^{-\alpha}$ with $\alpha \in (1/2, 1)$ and $\gamma_0 > 0$. As suggested by Ruppert (1988) and Polyak & Juditsky (1992), the final SGD estimate is often taken as the averaging estimate,
$$\bar{\beta}_N = \frac{1}{N} \sum_{n=1}^{N} \beta_n. \tag{4}$$
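The recursion (3) and the averaging step (4) can be sketched in a few lines. The following is a minimal illustration rather than the authors' code: the names (`sgd_average`, `burn_in`) and the step-size constants are our own choices, and least squares is used only as a concrete loss.

```python
import numpy as np

def sgd_average(data, grad, beta0, gamma0=0.1, alpha=0.501, burn_in=0):
    """One-pass averaged (Polyak-Ruppert) SGD, following recursions (3)-(4).

    grad(beta, z) returns the gradient of the loss at one observation z;
    the learning rate is gamma_n = gamma0 * n**(-alpha) with 1/2 < alpha < 1.
    """
    beta = np.asarray(beta0, dtype=float).copy()
    total = np.zeros_like(beta)
    kept = 0
    for n, z in enumerate(data, start=1):
        beta = beta - gamma0 * n ** (-alpha) * grad(beta, z)  # recursion (3)
        if n > burn_in:                # average the later iterates, as in (4)
            total += beta
            kept += 1
    return total / kept

# Usage with least squares: loss (y - x'b)^2 / 2, gradient -(y - x'b) x.
rng = np.random.default_rng(0)
N, beta_star = 20000, np.array([1.0, -2.0, 0.5])
X = rng.normal(size=(N, 3))
y = X @ beta_star + rng.normal(size=N)
ls_grad = lambda b, z: -(z[1] - z[0] @ b) * z[0]
beta_bar = sgd_average(zip(X, y), ls_grad, np.zeros(3), burn_in=2000)
```

Each observation is touched exactly once, and only the current iterate and the running average are stored, which is what makes the method suitable for streaming data.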
To conduct statistical inference with the averaging SGD estimator $\bar{\beta}_N$, we propose a perturbation resampling procedure, which recursively updates the SGD estimate as well as a large number of randomly perturbed SGD estimates upon the arrival of each data point. Specifically, let $\{W_n, n \ge 1\}$ be a set of i.i.d. nonnegative random variables with mean and variance equal to one. In parallel with (3) and (4), with $\beta_0^* = \beta_0$, upon observing data point $Z_n$, we recursively update the randomly perturbed SGD estimate,
$$\beta_n^* = \beta_{n-1}^* - \gamma_n W_n \nabla f(\beta_{n-1}^*; Z_n), \tag{5}$$
$$\bar{\beta}_N^* = \frac{1}{N} \sum_{n=1}^{N} \beta_n^*. \tag{6}$$
We will show that $\sqrt{N}(\bar{\beta}_N - \beta^*)$ and $\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N)$ converge in distribution to the same limiting distribution. In practice, these results allow us to estimate the distribution of $\bar{\beta}_N$ by generating a large number, say $B$, of random samples of $\{W_n\}$. We obtain $\bar{\beta}_N^{*(b)}$ by sequentially updating the perturbed SGD estimates for each sample $b = 1, \ldots, B$,
$$\beta_n^{*(b)} = \beta_{n-1}^{*(b)} - \gamma_n W_n^{(b)} \nabla f(\beta_{n-1}^{*(b)}; Z_n), \tag{7}$$
$$\bar{\beta}_N^{*(b)} = \frac{1}{N} \sum_{n=1}^{N} \beta_n^{*(b)}, \tag{8}$$
and then approximate the sampling distribution of $\bar{\beta}_N$ by the empirical distribution of $\{\bar{\beta}_N^{*(b)}, b = 1, \ldots, B\}$. Specifically, the covariance matrix of $\bar{\beta}_N$ can be estimated by the sample covariance matrix constructed from $\{\bar{\beta}_N^{*(b)}\}$. Estimating the distribution of $\sqrt{N}(\bar{\beta}_N - \beta^*)$ by the conditional distribution of $\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N)$ leads to the construction of confidence regions for $\beta^*$. The resulting inferential procedure retains the numerical simplicity of the SGD method, using only one pass over the data. It scales well for datasets with millions of data points, and its theoretical validity can be justified under two general model settings with mild regularity conditions, as shown in the next section.
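A one-pass implementation of (3)-(8) can be sketched as follows. This is an illustrative sketch rather than the authors' implementation: the names (`perturbed_sgd`, `chains`) are ours, and standard exponential weights are used as one admissible choice of nonnegative weights with unit mean and variance.

```python
import numpy as np

def perturbed_sgd(data, grad, beta0, B=100, gamma0=0.1, alpha=0.501,
                  burn_in=0, seed=0):
    """Run the SGD recursion (3)-(4) together with B randomly perturbed
    copies (5)-(8) in a single pass over the data: in chain b, each
    arriving gradient is scaled by an i.i.d. nonnegative weight with
    mean 1 and variance 1 (standard exponential here)."""
    rng = np.random.default_rng(seed)
    beta = np.asarray(beta0, dtype=float).copy()
    chains = np.tile(beta, (B, 1))                # B perturbed chains
    tot = np.zeros_like(beta)
    tot_b = np.zeros_like(chains)
    kept = 0
    for n, z in enumerate(data, start=1):
        gamma_n = gamma0 * n ** (-alpha)
        beta = beta - gamma_n * grad(beta, z)
        w = rng.exponential(size=(B, 1))          # E(W) = Var(W) = 1
        g_b = np.array([grad(c, z) for c in chains])
        chains = chains - gamma_n * w * g_b
        if n > burn_in:
            tot += beta
            tot_b += chains
            kept += 1
    return tot / kept, tot_b / kept

# Usage: least squares; the spread of the perturbed copies estimates the
# sampling variability of the averaged SGD estimator.
rng = np.random.default_rng(1)
N, beta_star = 5000, np.array([1.0, -1.0])
X = rng.normal(size=(N, 2))
y = X @ beta_star + rng.normal(size=N)
ls_grad = lambda b, z: -(z[1] - z[0] @ b) * z[0]
beta_bar, copies = perturbed_sgd(zip(X, y), ls_grad, np.zeros(2),
                                 B=100, burn_in=500)
se = copies.std(axis=0, ddof=1)                   # perturbation-based SE
ci_low, ci_high = beta_bar - 1.96 * se, beta_bar + 1.96 * se
```

A 95% interval for each coefficient is then `beta_bar ± 1.96 * se`, mirroring the interval construction used in the numerical studies.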
3 Theoretical Results
In this section, we derive the theoretical properties of $\bar{\beta}_N^*$, justifying that the conditional distribution of $\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N)$ given the data can approximate the sampling distribution of $\sqrt{N}(\bar{\beta}_N - \beta^*)$, under the following two model settings.
3.1 Model Setting 1
We first consider the setting where the objective function $F(\beta)$ in (1) is smooth. This includes linear regression, logistic regression and other generalized linear models as special cases. To ensure the consistency and asymptotic properties of the SGD estimator and the validity of the proposed resampling procedure, we make the following assumptions.

Assumption A1. The objective function $F$ is continuously differentiable and strongly convex with constant $\mu > 0$; that is, for any $\beta_1$ and $\beta_2$, $F(\beta_2) \ge F(\beta_1) + \nabla F(\beta_1)'(\beta_2 - \beta_1) + \frac{\mu}{2}\|\beta_2 - \beta_1\|^2$.

Assumption A2. The gradient of $F$, $\nabla F$, is Lipschitz continuous with constant $L$; that is, for any $\beta_1$ and $\beta_2$, $\|\nabla F(\beta_1) - \nabla F(\beta_2)\| \le L\|\beta_1 - \beta_2\|$.

Assumption A3. Let $A(\beta) = \nabla^2 F(\beta)$ be the Hessian matrix of $F$. Assume that $A(\beta)$ exists and is continuous in a neighborhood of $\beta^*$, and write $A = A(\beta^*)$.

Assumption A4. Let $\xi(\beta; Z) = \nabla f(\beta; Z) - \nabla F(\beta)$. Let $S = E\{\nabla f(\beta^*; Z)\nabla f(\beta^*; Z)'\}$ and assume $E\|\xi(\beta^*; Z)\|^{2+\delta} < \infty$ for some $\delta > 0$. Assume $E\|\xi(\beta; Z)\|^2 \to E\|\xi(\beta^*; Z)\|^2$ as $\beta \to \beta^*$.
Following arguments similar to those in Ruppert (1988) and Polyak & Juditsky (1992), the SGD estimator $\bar{\beta}_N$ is asymptotically normal under Model Setting 1.
Lemma 1.
If Assumptions A1-A4 are satisfied, then we have
$$\sqrt{N}(\bar{\beta}_N - \beta^*) \rightarrow N(0, A^{-1} S A^{-1}) \text{ in distribution.} \tag{9}$$
By Lemma 1, we can use the plug-in procedure to estimate the asymptotic covariance matrix $A^{-1} S A^{-1}$ of $\bar{\beta}_N$, where $A$ and $S$ can be conveniently estimated recursively using
$$A_n = A_{n-1} + \{\nabla^2 f(\beta_{n-1}; Z_n) - A_{n-1}\}/n, \tag{10}$$
$$S_n = S_{n-1} + \{\nabla f(\beta_{n-1}; Z_n)\nabla f(\beta_{n-1}; Z_n)' - S_{n-1}\}/n. \tag{11}$$
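For a concrete instance of these recursions, the sketch below maintains running means of the per-observation Hessians and gradient outer products alongside the SGD iterates, for the least-squares loss where the Hessian at one observation is simply $x x'$. The constants and variable names are our own illustrative choices, not the authors' code.

```python
import numpy as np

rng = np.random.default_rng(1)
p, N = 3, 20000
beta_star = np.array([1.0, -2.0, 0.5])
A = np.zeros((p, p))          # running mean of per-observation Hessians
S = np.zeros((p, p))          # running mean of gradient outer products
beta = np.zeros(p)
for n in range(1, N + 1):
    x = rng.normal(size=p)
    y = x @ beta_star + rng.normal()
    g = -(y - x @ beta) * x                  # gradient at current iterate
    A += (np.outer(x, x) - A) / n            # recursion in the spirit of (10)
    S += (np.outer(g, g) - S) / n            # recursion in the spirit of (11)
    beta = beta - 0.1 * n ** (-0.501) * g    # SGD step (3)
A_inv = np.linalg.inv(A)
cov_hat = A_inv @ S @ A_inv / N              # plug-in sandwich covariance
se = np.sqrt(np.diag(cov_hat))               # plug-in standard errors
```

Each update costs $O(p^2)$ per observation, so the plug-in estimate can also be computed in a single pass, but note that it still requires an explicit Hessian and one matrix inversion at the end.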
We illustrate this setting and the regularity conditions in two examples. The data consist of $Z_i = (X_i, Y_i)$, $i = 1, \ldots, N$, which are i.i.d. copies of $Z = (X, Y)$, where $Y$ denotes the response variable and $X$ is the $p$-dimensional vector of covariates. Assume that $E(XX')$ is positive definite.
Example 1 (Linear regression). Suppose that $Z_i = (X_i, Y_i)$, $i = 1, \ldots, N$, are from the linear regression model,
$$Y_i = X_i'\beta^* + \epsilon_i, \tag{12}$$
where the errors $\epsilon_i$ are i.i.d. with mean zero and variance $\sigma^2$, and $\epsilon_i$ and $X_i$ are mutually independent. Let $f(\beta; Z) = (Y - X'\beta)^2/2$, so that $\nabla f(\beta; Z) = -(Y - X'\beta)X$, $A = E(XX')$ and $S = \sigma^2 E(XX')$. It can be easily verified that Assumptions A1-A4 hold, and the SGD and perturbed SGD updates for $\beta$, as defined in (3) and (5) respectively, are
$$\beta_n = \beta_{n-1} + \gamma_n (Y_n - X_n'\beta_{n-1}) X_n, \tag{13}$$
$$\beta_n^* = \beta_{n-1}^* + \gamma_n W_n (Y_n - X_n'\beta_{n-1}^*) X_n. \tag{14}$$
3.2 Model Setting 2
The model setting 1 includes smooth objective function in general, not necessarily restricted to the regression case. Next we consider a general regression setting that allows for nonsmooth loss functions, including quantile regression as a special case. Suppose that the data, , , are from the model (12), and the loss function is
(18) 
where is a convex function with . We require the following regularity conditions.

Assumption B1. Assume that $\epsilon_i$, $i = 1, \ldots, N$, are i.i.d. copies of $\epsilon$, that $\epsilon_i$ and $X_i$ are mutually independent, and that $E(XX')$ is positive definite. Let $\Sigma = E(XX')$.

Assumption B2. Assume that $\rho(\cdot)$ is a convex function on $\mathbb{R}$ with right derivative $\psi_+(\cdot)$ and left derivative $\psi_-(\cdot)$. Let $\psi(\cdot)$ be a function such that $\psi_-(t) \le \psi(t) \le \psi_+(t)$. There exists a constant $c_0$ such that $|\psi(t)| \le c_0(1 + |t|)$.

Assumption B3. Let $G(t) = E\{\psi(\epsilon + t)\}$. Assume that $G(0) = 0$, $tG(t) > 0$ for any $t \ne 0$, and $G$ has a derivative at $0$ with $G'(0) > 0$. There exist constants $c_1 > 0$ and $t_0 > 0$ such that $|G(t)| \le c_1 |t|$ for $|t| \le t_0$.

Assumption B4. Let $H(t) = E\{\psi(\epsilon + t)^2\}$. Assume that $H(t)$ is finite for $t$ in a neighborhood of $0$ and is continuous at $0$.
By Assumption B2, the SGD and perturbed SGD updates for $\beta$, as defined in (3) and (5) respectively, are
$$\beta_n = \beta_{n-1} + \gamma_n \psi(Y_n - X_n'\beta_{n-1}) X_n, \tag{19}$$
$$\beta_n^* = \beta_{n-1}^* + \gamma_n W_n \psi(Y_n - X_n'\beta_{n-1}^*) X_n. \tag{20}$$
We establish the asymptotic normality of the SGD estimator under Model Setting 2 as follows.
Lemma 2.
If Assumptions B1-B4 are satisfied, then we have
$$\sqrt{N}(\bar{\beta}_N - \beta^*) \rightarrow N\left(0, \frac{H(0)}{\{G'(0)\}^2}\,\Sigma^{-1}\right) \text{ in distribution.} \tag{21}$$
We illustrate Model Setting 2 with two examples.
Example 1 (Linear regression, revisited). We revisit Example 1. Let $\rho(t) = t^2/2$. We have $\psi(t) = t$, $G(t) = E(\epsilon + t) = t$, and $H(t) = E(\epsilon + t)^2 = \sigma^2 + t^2$. Thus, $G(0) = 0$ is equivalent to $E(\epsilon) = 0$, and consequently $G'(0) = 1$ and $H(0) = \sigma^2$. Therefore, the asymptotic covariance matrix in (21) is $\sigma^2 \Sigma^{-1}$, which agrees with Lemma 1.
Example 3 (Quantile regression). Consider $\rho_\tau(t) = t\{\tau - I(t < 0)\}$, where $\tau \in (0, 1)$. Then $\psi(t) = \tau - I(t < 0)$, $G(t) = \tau - F_\epsilon(-t)$, and $H(t) = E\{\tau - I(\epsilon < -t)\}^2$, where $F_\epsilon$ denotes the distribution function of $\epsilon$. Thus, $G(0) = 0$ is equivalent to the $\tau$th quantile of $\epsilon$ being $0$, and $G'(0) = f_\epsilon(0)$, where $f_\epsilon$ is the density of $\epsilon$. Then the SGD and perturbed SGD updates for $\beta$, as defined in (3) and (5) respectively, are
$$\beta_n = \beta_{n-1} + \gamma_n X_n \{\tau - I(Y_n - X_n'\beta_{n-1} < 0)\}, \tag{22}$$
$$\beta_n^* = \beta_{n-1}^* + \gamma_n W_n X_n \{\tau - I(Y_n - X_n'\beta_{n-1}^* < 0)\}, \tag{23}$$
and the asymptotic covariance matrix in (21) is $\tau(1-\tau) f_\epsilon(0)^{-2} \Sigma^{-1}$. As the covariance matrix involves the unknown density function, the plug-in procedure is not applicable in this example.
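To make the update (22) concrete, here is a minimal one-pass sketch of SGD for median (LAD) regression. It is our own illustration (names and constants are ours), with standard normal errors so that the true $\tau = 0.5$ quantile of the error is zero.

```python
import numpy as np

rng = np.random.default_rng(2)
p, N, tau = 3, 20000, 0.5
beta_star = np.array([1.0, -2.0, 0.5])
X = rng.normal(size=(N, p))
y = X @ beta_star + rng.normal(size=N)       # median of the error is 0
beta = np.zeros(p)
total = np.zeros(p)
kept = 0
for n in range(1, N + 1):
    x, yn = X[n - 1], y[n - 1]
    gamma_n = 0.5 * n ** (-0.501)
    # subgradient step for the check loss: psi(u) = tau - 1{u < 0}
    beta = beta + gamma_n * x * (tau - float(yn - x @ beta < 0))
    if n > 2000:                             # average the later iterates
        total += beta
        kept += 1
beta_bar = total / kept
```

Replacing `tau` by another value in (0, 1) gives the corresponding quantile regression; no density estimation is needed to run the recursion, which is exactly why the resampling procedure, rather than the plug-in procedure, is used for inference here.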
3.3 Asymptotic properties
Let $P^*$ and $E^*$ denote the conditional probability and expectation given the data $\{Z_1, \ldots, Z_N\}$, respectively. Recall that the perturbation variables $W_n$ satisfy $E(W_n) = \mathrm{Var}(W_n) = 1$ and that the learning rate is $\gamma_n = \gamma_0 n^{-\alpha}$ with $\alpha \in (1/2, 1)$ and $\gamma_0 > 0$. We derive the following two theorems for Model Settings 1 and 2, respectively.
Theorem 1.
(Model Setting 1) If Assumptions A1-A4 hold, then we have (i),
$$\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N) \rightarrow N(0, A^{-1} S A^{-1}) \text{ in conditional distribution given the data, in probability,} \tag{24}$$
and (ii),
$$\sup_x \left| P^*\{\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N) \le x\} - P\{\sqrt{N}(\bar{\beta}_N - \beta^*) \le x\} \right| \rightarrow 0 \text{ in probability.} \tag{25}$$
Theorem 2.
(Model Setting 2) If Assumptions B1-B4 hold, then we have (i),
$$\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N) \rightarrow N\left(0, \frac{H(0)}{\{G'(0)\}^2}\,\Sigma^{-1}\right) \text{ in conditional distribution given the data, in probability,} \tag{26}$$
and (ii),
$$\sup_x \left| P^*\{\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N) \le x\} - P\{\sqrt{N}(\bar{\beta}_N - \beta^*) \le x\} \right| \rightarrow 0 \text{ in probability.} \tag{27}$$
By Theorems 1 and 2, under either Model Setting 1 or Model Setting 2, the Kolmogorov-Smirnov distance between the conditional distribution of $\sqrt{N}(\bar{\beta}_N^* - \bar{\beta}_N)$ given the data and the sampling distribution of $\sqrt{N}(\bar{\beta}_N - \beta^*)$ converges to zero in probability. This validates our proposal of the perturbation-based resampling procedure for inference with SGD.
4 Numerical results
4.1 Simulation studies
To assess the performance of the proposed perturbation-based resampling (a.k.a. random weighting; RW) procedure for SGD estimators, we conduct simulation studies for the three examples discussed in Section 3. We compare the proposed procedure with the plug-in procedure, when applicable, as described in (10) and (11). We do not compare with the batch-means procedure proposed by Chen et al. (2016), because their program is not publicly available and depends on several tuning choices (personal communications).
Example 1 (Least-squares regression): Consider model (12), where the covariates and the error are independently generated from the standard normal distribution. The useful covariates have coefficients equal to the effect size and the remaining coefficients are zero (the same design is used in the other two examples). Consider least-squares (LS) regression; the corresponding SGD estimators are the ones defined in (13) and (14).
Example 2 (Logistic regression): Consider logistic (Logit) regression (15), where the covariates are independently generated and the response is generated from the Bernoulli distribution specified by the logistic model. The corresponding SGD estimators are the ones defined in (16) and (17).
Example 3 (Least-absolute-deviation regression): Consider model (12), where the covariates and the error are generated independently. Consider quantile regression with $\tau = 0.5$, which is equivalent to least-absolute-deviation (LAD) regression. The corresponding SGD estimators are the ones defined in (22) and (23) with $\tau = 0.5$.
Table 1: Empirical coverage probabilities of the 95% confidence intervals in Example 1 (least-squares regression), by the random weighting (RW) and plug-in procedures.

(N, p, s, b)  Method  Dim 1  Dim 4  Dim 7
(10000,10,6,0.1)  RW  0.962  0.946  0.948 
Plug in  0.901  0.917  0.900  
(10000,10,6,0.2)  RW  0.940  0.948  0.953 
Plug in  0.898  0.924  0.902  
(10000,10,6,0.3)  RW  0.937  0.945  0.943 
Plug in  0.908  0.904  0.906  
(20000,20,6,0.1)  RW  0.952  0.966  0.969 
Plug in  0.893  0.882  0.902  
(20000,20,6,0.2)  RW  0.957  0.962  0.969 
Plug in  0.918  0.902  0.927  
(20000,20,6,0.3)  RW  0.965  0.954  0.961 
Plug in  0.913  0.918  0.926 
Table 2: Average estimated standard errors (RW and Plug-in) and empirical standard errors in Example 1 (least-squares regression).

(N, p, s, b)  Method  Dim 1  Dim 4  Dim 7
(10000,10,6,0.1)  RW  0.0157  0.0158  0.0158 
Plug in  0.0137  0.0137  0.0137  
Empirical  0.0156  0.0157  0.0158  
(10000,10,6,0.2)  RW  0.0158  0.0158  0.0158 
Plug in  0.0137  0.0137  0.0137  
Empirical  0.0164  0.0157  0.0154  
(10000,10,6,0.3)  RW  0.0158  0.0158  0.0158 
Plug in  0.0137  0.0137  0.0137  
Empirical  0.0164  0.0162  0.0163  
(20000,20,6,0.1)  RW  0.0114  0.0114  0.0114 
Plug in  0.0096  0.0096  0.0096  
Empirical  0.0114  0.0104  0.0104  
(20000,20,6,0.2)  RW  0.0114  0.0114  0.0115 
Plug in  0.0096  0.0096  0.0096  
Empirical  0.0109  0.0108  0.0105  
(20000,20,6,0.3)  RW  0.0115  0.0115  0.0114 
Plug in  0.0096  0.0096  0.0096  
Empirical  0.0108  0.0108  0.0107 
Table 3: Empirical coverage probabilities of the 95% confidence intervals in Example 2 (logistic regression), by the RW and plug-in procedures.

(N, p, s, b)  Method  Dim 1  Dim 4  Dim 7
(10000,10,6,0.1)  RW  0.955  0.955  0.961 
Plug in  0.900  0.910  0.877  
(10000,10,6,0.2)  RW  0.969  0.951  0.956 
Plug in  0.887  0.878  0.878  
(10000,10,6,0.3)  RW  0.966  0.970  0.954 
Plug in  0.881  0.886  0.895  
(20000,20,6,0.1)  RW  0.947  0.960  0.943 
Plug in  0.878  0.891  0.890  
(20000,20,6,0.2)  RW  0.957  0.952  0.931 
Plug in  0.891  0.875  0.885  
(20000,20,6,0.3)  RW  0.963  0.959  0.938 
Plug in  0.861  0.853  0.861 
Table 4: Average estimated standard errors (RW and Plug-in) and empirical standard errors in Example 2 (logistic regression).

(N, p, s, b)  Method  Dim 1  Dim 4  Dim 7
(10000,10,6,0.1)  RW  0.0234  0.0233  0.0232 
Plug in  0.0119  0.0119  0.0120  
Empirical  0.0228  0.0227  0.0231  
(10000,10,6,0.2)  RW  0.0246  0.0246  0.0240 
Plug in  0.0111  0.0111  0.0111  
Empirical  0.0226  0.0241  0.0229  
(10000,10,6,0.3)  RW  0.0268  0.0268  0.0254 
Plug in  0.0100  0.0100  0.0103  
Empirical  0.0250  0.0245  0.0251  
(20000,20,6,0.1)  RW  0.0158  0.0157  0.0157 
Plug in  0.0084  0.0084  0.0085  
Empirical  0.0160  0.0153  0.0156  
(20000,20,6,0.2)  RW  0.0165  0.0165  0.0161 
Plug in  0.0078  0.0078  0.0079  
Empirical  0.0161  0.0164  0.0168  
(20000,20,6,0.3)  RW  0.0182  0.0181  0.0169 
Plug in  0.0069  0.0069  0.0073  
Empirical  0.0166  0.0173  0.0170 
Table 5: Empirical coverage probabilities of the 95% confidence intervals in Example 3 (LAD regression), by the RW procedure; the plug-in procedure is not applicable.

(N, p, s, b)  Method  Dim 1  Dim 4  Dim 7
(10000,10,6,0.1)  RW  0.968  0.956  0.960 
Plug in  
(10000,10,6,0.2)  RW  0.958  0.953  0.966 
Plug in  
(10000,10,6,0.3)  RW  0.956  0.963  0.959 
Plug in  
(20000,20,6,0.1)  RW  0.971  0.962  0.969 
Plug in  
(20000,20,6,0.2)  RW  0.959  0.969  0.966 
Plug in  
(20000,20,6,0.3)  RW  0.953  0.959  0.960 
Plug in 
Table 6: Average estimated standard errors (RW) and empirical standard errors in Example 3 (LAD regression); the plug-in procedure is not applicable.

(N, p, s, b)  Method  Dim 1  Dim 4  Dim 7
(10000,10,6,0.1)  RW  0.0130  0.0129  0.0129 
Plug in  
Empirical  0.0119  0.0117  0.0120  
(10000,10,6,0.2)  RW  0.0130  0.0129  0.0129 
Plug in  
Empirical  0.0121  0.0120  0.0120  
(10000,10,6,0.3)  RW  0.0129  0.0130  0.0130 
Plug in  
Empirical  0.0129  0.0117  0.0122  
(20000,20,6,0.1)  RW  0.0091  0.0090  0.0091 
Plug in  
Empirical  0.0081  0.0085  0.0081  
(20000,20,6,0.2)  RW  0.0090  0.0090  0.0090 
Plug in  
Empirical  0.0086  0.0081  0.0083  
(20000,20,6,0.3)  RW  0.0091  0.0090  0.0090 
Plug in  
Empirical  0.0083  0.0083  0.0084 
For each example, we consider six scenarios, described by the tuple (sample size $N$, number of covariates $p$, number of useful covariates $s$, effect size $b$), where $N = 10000$ or $20000$, $p = 10$ or $20$, $s = 6$, and $b = 0.1$, $0.2$ or $0.3$. For each example, we repeat the data generation 1000 times. For each data repetition, we use i.i.d. nonnegative random weights with mean and variance one, and generate $B$ copies of the random weights whenever a new data point is read. Then, for each data repetition, we obtain the SGD estimator (4), apply the proposed perturbation-based resampling procedure to estimate its standard error, and apply the plug-in procedure (when applicable) to estimate its standard error as well. When we calculate the averaged SGD estimators (4) and (6), the first 2000 estimates are excluded. Based on the estimated standard error, we construct a 95% confidence interval of the form (estimate $\pm$ 1.96 $\times$ standard error) and check whether it covers the true parameter value. We also obtain the empirical standard error based on the 1000 repeated SGD estimators, which is considered a good approximation to the true standard error.
The coverage probabilities of the 95% confidence interval estimates constructed using our procedure (RW) and the plug-in procedure (Plug-in) are summarized in Tables 1, 3 and 5 for Examples 1-3, respectively. We only report results corresponding to the first, fourth and seventh covariates, and the plug-in procedure is not applicable for Example 3. From these tables, we see that the coverage probabilities from the RW procedure are close to the nominal 95%, while those from the plug-in procedure are substantially smaller than 95%. Similar findings for the plug-in procedure were also reported in Chen et al. (2016). Therefore, our procedure outperforms the plug-in procedure.
We also compare the average estimated standard errors (SE) from the RW and plug-in procedures with the empirical standard errors, which are regarded as close to the true standard errors. The results are summarized in Tables 2, 4 and 6 for Examples 1-3, respectively. Again, we only report results for those three covariates, and the plug-in procedure is not applicable for Example 3. From these tables, we see that the average estimated standard errors from the RW procedure are close to the empirical standard errors, while those from the plug-in procedure are substantially smaller.
4.2 Real data applications
In this section, we apply the proposed method to conduct a linear regression analysis of the individual household electric power consumption dataset (POWER) and a logistic regression analysis of the gas sensors for home activity monitoring dataset (GAS). Both the POWER data and the GAS data are publicly available in the UCI Machine Learning Repository.
The POWER data contain 2,075,259 observations. We fit a linear regression model to investigate the relationship between time and the response variable "sub-metering 1", the energy sub-metering No. 1 (in watt-hours of active energy), which corresponds to the kitchen, containing mainly a dishwasher, an oven and a microwave. Observations with missing values are deleted, and time is divided into 8 categories: "0-2", "3-5", "6-8", "9-11", "12-14", "15-17", "18-20" and "21-23". The GAS data contain 919,438 observations, and we only use a subset of 652,024 observations whose response value is either "banana" or "wine". We consider a logistic regression model to examine the association between the response variable and 11 covariates: time, R1 to R8, temperature and humidity.
Although standard software packages such as SAS and R can fit linear and logistic regression to such datasets without difficulty, for illustration we use SGD as in Examples 1 and 2 to fit the linear and logistic regression models and use the proposed perturbation-based resampling procedure to construct confidence intervals. The point estimates and 95% confidence intervals of the coefficients are shown in Tables 7 and 8 for the POWER data and the GAS data, respectively. From Table 7, we see that the electric power consumption in the kitchen is relatively high in the evening and at night. From Table 8, we see that all the variables but R4 are statistically significantly associated with the response. Further, we display the histogram of the perturbation-based SGD estimates for each coefficient in Figures 1 and 2 for the POWER data and the GAS data, respectively. The vertical line in each figure indicates the SGD estimate of the corresponding coefficient. From these figures, we see that the perturbation-based procedure can be used to estimate the whole sampling distribution, not only the standard error, of each SGD estimator.
Variable  Point estimate  95% CI 

Time 02  
Time 35  
Time 68  
Time 911  
Time 1214  
Time 1517  
Time 1820  
Time 2123 
Variable  Point estimate  95% CI 

Time  
R1  
R2  
R3  
R4  
R5  
R6  
R7  
R8  
Temperature  
Humidity 
5 Discussion
Online updating is a useful strategy for analyzing big data and streaming data, and recently stochastic gradient descent has become a popular method for online updating. Although the asymptotic properties of SGD have been well studied, there is little research on conducting statistical inference based on SGD estimators. In this paper, we propose a perturbation-based resampling procedure that can be applied to estimate the sampling distribution of an SGD estimator. The offline version of the perturbation-based resampling procedure was first proposed by Rubin (1981) and was also discussed in Shao & Tu (2012).
The proposed resampling procedure is in essence an online version of the bootstrap. Recall that the data points $Z_1, Z_2, \ldots$ arrive one at a time, and the SGD estimator updates itself from $\beta_{n-1}$ to $\beta_n$ whenever a new data point arrives. If we were forced to apply the bootstrap, we would have many bootstrap samples; the data points of each bootstrap sample would be assumed to arrive one at a time, and the SGD estimator for that sample would update itself whenever a new data point arrives. Of course, the bootstrap is impractical here because in online updating we cannot collect all the data points and then generate bootstrap samples. Now, if we rearrange a hypothetical bootstrap sample so that observation $Z_n$ appears $K_n$ times, where $K_n$ follows the Binomial$(N, 1/N)$ distribution, then the SGD estimator updates itself whenever a new batch of data points, namely $K_n$ copies of $Z_n$, arrives. Noting that the Binomial$(N, 1/N)$ distribution converges to the Poisson$(1)$ distribution as $N \to \infty$, we see that the aforementioned hypothetical bootstrap is approximately equivalent to our proposed perturbation-based resampling procedure with $W_n \sim \mathrm{Poisson}(1)$, whose mean and variance are both equal to one.
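The Binomial-to-Poisson argument is easy to check numerically. The snippet below is our own illustration for the sample mean (not an SGD estimator): each observation is weighted by an independent Poisson(1) count, so each chain behaves like a bootstrap resample, and the spread of the weighted means across chains approximates the true standard error $\sigma/\sqrt{N}$.

```python
import numpy as np

rng = np.random.default_rng(4)
N, B, mu, sigma = 20000, 200, 2.0, 3.0
z = rng.normal(mu, sigma, size=N)
# Poisson(1) counts: mean 1, variance 1, mimicking how many times each
# observation would appear in a bootstrap resample of size N.
w = rng.poisson(1.0, size=(N, B))
boot_means = (w * z[:, None]).sum(axis=0) / w.sum(axis=0)
se_hat = boot_means.std(ddof=1)      # approximates sigma / sqrt(N) ~ 0.0212
```

The same weights applied multiplicatively to the arriving gradients give the online analogue: no resample is ever materialized, yet the chains reproduce bootstrap variability.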
Finally, the SGD method considered in this paper is the explicit SGD, in contrast to the implicit SGD considered in Toulis & Airoldi (2014). We are working on extending the perturbation-based resampling procedure proposed in this paper to statistical inference for the implicit SGD.
Appendix
For ease of exposition in establishing the asymptotic normality of the SGD and perturbed SGD estimates, we present the following Proposition 1, adapted from Polyak & Juditsky (1992), page 841, Theorem 2. Let be some unknown function and . The data consist of which are i.i.d. copies of . Stochastic gradients are and . With an initial point and the learning rate , the SGD estimate is defined as
(A.1) 
where , and . The regularity conditions for Proposition 1 are listed as follows.

Assumption C1. There exists a function such that for some , , , , and all , the conditions , , , for hold true. Moreover, for all .

Assumption C2. There exists a positive definite matrix such that for some , , and , the condition for all holds true.

Assumption C3. The noise process is a martingale difference process; that is, almost surely, and for some ,
for all . Consider decomposition , where and . Assume that a.s.,
and there exists as such that, for all large enough,
Proposition 1. If Assumptions C1-C3 are satisfied, then (i): , a.s.;
and (ii):
(A.2) 
and
Proof of Lemma 1:
By Proposition 1, it is sufficient to show that Assumptions C1-C3 hold under Assumptions A1-A4.
Let , , and .
Assumption C1 follows easily from Assumptions A1-A3. Note that Assumption A1 implies that , Assumption A2 implies that for , Assumption A1 implies that for , and Assumption A3 implies that over some neighborhood of . Next, we can see that Assumption A3 implies that Assumption C2 holds. Finally, noting that , where and , we can see that Assumptions A1 and A4 imply that the conditions in Assumption C3 about and are satisfied.
Proof of Lemma 2:
Define , and . Let and . We verify that Assumptions C1-C3 hold if Assumptions B1-B4 are satisfied. First, let . Assumption B2 implies that the conditions about in Assumption C1 are satisfied. Second, Assumption B3 implies that for some over a neighborhood of , , and . Third, to verify Assumption C3, we consider the decomposition , where and . By Assumptions B1 and B4, we can see that and the conditions about in Assumption C3 are satisfied, and by Assumptions B2 and B4, we can see that the condition about in Assumption C3 is satisfied. Lemma 2 then follows from Proposition 1 and
(A.3) 
Proof of Theorem 1:
(i). Rewrite as
(A.4)  
where . Let denote the Borel field generated by . Since and , we have . Thus is a martingale-difference process. Let . Then , where
(A.5) 
and . Since , we have and