Learning Structured Distributions From Untrusted Batches: Faster and Simpler

Abstract

We revisit the problem of learning from untrusted batches introduced by Qiao and Valiant [qiao2017learning]. Recently, Jain and Orlitsky [jain2019robust] gave a simple semidefinite programming approach based on the cut-norm that achieves essentially information-theoretically optimal error in polynomial time. Concurrently, Chen et al. [chen2019efficiently] considered a variant of the problem where the distribution $p$ is assumed to be structured, e.g. log-concave, monotone hazard rate, $t$-modal, etc. In this case, it is possible to achieve the same error with sample complexity sublinear in the domain size $n$, and they exhibited a quasi-polynomial time algorithm for doing so using Haar wavelets.

In this paper, we find an appealing way to synthesize the techniques of [jain2019robust] and [chen2019efficiently] to give the best of both worlds: an algorithm which runs in polynomial time and can exploit structure in the underlying distribution to achieve sublinear sample complexity. Along the way, we simplify the approach of [jain2019robust] by avoiding the need for SDP rounding and giving a more direct interpretation of it through the lens of soft filtering, a powerful recent technique in high-dimensional robust estimation.

1 Introduction

In this paper, we consider the problem of learning structured distributions from untrusted batches. This is a variant of the problem of learning from untrusted batches introduced in [qiao2017learning]. Here, there is an unknown distribution $p$ over a domain of size $n$, and we are given $N$ batches of samples, each of size $k$. A $(1-\epsilon)$-fraction of these batches are “good,” and consist of $k$ i.i.d. samples from some distribution with distance at most $\eta$ from $p$ in total variation distance, but an $\epsilon$-fraction of these batches are “bad,” and can be adversarially corrupted. The goal then is to estimate $p$ in total variation distance.

This problem models a situation where we get batches of data from many different users, for instance, in a crowdsourcing application. Each honest user provides a relatively small batch of data, which is by itself insufficient to learn a good model and which, moreover, may come from a slightly different distribution depending on the user, due to heterogeneity. At the same time, a non-trivial fraction of data can come from malicious users who wish to game our algorithm to their own ends. The high-level question is whether or not we can exploit the batch structure of our data to improve the robustness of our estimator.

For this problem, there are three separate, but equally important, metrics under which we can evaluate any estimator:

Robustness

How accurately can we estimate $p$ in total variation distance?

Runtime

Are there algorithms that run in polynomial time in all the relevant parameters?

Sample complexity

How few samples do we need in order to estimate $p$?

In the original paper, Qiao and Valiant [qiao2017learning] focus primarily on robustness. They give an algorithm for learning a general distribution $p$ from untrusted batches that uses a polynomial number of samples and estimates $p$ to within

in total variation distance, and they proved that this is the best possible up to constant factors. However, their estimator does not run in polynomial time. Qiao and Valiant [qiao2017learning] also gave an algorithm based on low-rank tensor approximation; however, that algorithm needs a correspondingly large number of samples.

A natural question is whether or not this robustness can be achieved efficiently. In [chen2019efficiently] we gave a quasipolynomial time algorithm for the general problem based on the sum-of-squares hierarchy. It estimates $p$ to within

in total variation distance. In concurrent and independent work, Jain and Orlitsky [jain2019robust] gave a polynomial-time algorithm based on a much simpler semidefinite program that estimates $p$ to within the same total variation distance. Their approach was based on an elegant way to combine approximation algorithms for the cut-norm [alon2004approximating] with the filtering approach for robust estimation [diakonikolas2019robust, steinhardt2018resilience, diakonikolas2017being, diakonikolas2018sever, dong2019quantum].

To some extent, the results of [chen2019efficiently, jain2019robust] also address the third consideration, sample complexity. In particular, the estimator of [jain2019robust] achieves the error rate mentioned above with a number of batches that grows only linearly (up to logarithmic factors) in $n$. Without any assumptions on the structure of $p$, even in the case where there are no corruptions, any algorithm needs a number of batches growing linearly in $n$ in order to learn $p$ to comparable accuracy. Thus, this sample complexity is nearly optimal for this problem, unless we make additional assumptions.

Unfortunately, in many cases the domain size $n$ can be very large, and a sample complexity that grows with $n$ can render the estimator impractical. However, in most applications we have prior knowledge about the shape of $p$ that could in principle be used to drastically reduce the sample complexity. For example, if $p$ is log-concave, monotone, or multimodal with a bounded number of modes, it is known that $p$ can be approximated by a piecewise polynomial function, and when there are no corruptions this meta structural property can be used to reduce the sample complexity to logarithmic in the domain size [chan2014efficient]. An appealing aspect of the relaxation in [chen2019efficiently] was that it was possible to incorporate shape constraints into the relaxation through the Haar wavelet basis, which allowed us to improve the running time and sample complexity to quasipolynomial in $d$ and $t$, respectively the degree and number of parts in the piecewise polynomial approximation, and quasi-polylogarithmic in $n$. Unfortunately, while [jain2019robust] achieves better runtime and sample complexity in the unstructured setting, their techniques do not obviously extend to obtain a similar sample complexity under structural assumptions.

This raises a natural question: can we build on [jain2019robust] and [chen2019efficiently] to incorporate shape constraints into a simple semidefinite programming approach that achieves nearly-optimal robustness in polynomial time, with sample complexity sublinear in $n$? In this paper, we answer this question in the affirmative: {theorem}[Informal, see Theorem 3] Let $p$ be a distribution over a domain of size $n$ that is approximated by a $t$-part piecewise polynomial function with degree at most $d$. Then there is a polynomial-time algorithm which estimates $p$ to within

(1)

in total variation distance after drawing $N$ $\epsilon$-corrupted batches, each of size $k$, where

(2)

is the number of batches needed. Any algorithm for learning structured distributions from untrusted batches must take at least batches to achieve error , and an interesting open question is whether there is a polynomial time algorithm that achieves these bounds. For robustly estimating the mean of a Gaussian in high-dimensions, there is evidence for a gap between the best possible estimation error and what can be achieved by polynomial time algorithms [diakonikolas2017statistical]. It seems plausible that the gap between the best possible estimation error and what we achieve is unavoidable in this setting as well.

1.1 High-Level Argument

[jain2019robust] demonstrated how to learn general distributions from untrusted batches in polynomial time using a filtering algorithm similar to those found in [diakonikolas2019robust, steinhardt2018resilience, diakonikolas2017being, diakonikolas2018sever, dong2019quantum], and in [chen2019efficiently] it was shown how to learn structured distributions from untrusted batches in quasipolynomial time using an SoS relaxation based on Haar wavelets.

In this work we show how to combine the filtering framework of [jain2019robust] with the Haar wavelet technology of [chen2019efficiently] to obtain a polynomial-time, sample-efficient algorithm for learning structured distributions from untrusted batches. In the discussion in this section, we will specialize to the case of $\eta = 0$ for the sake of clarity.

Learning via Filtering

A useful first observation is that the problem of learning from untrusted batches can be thought of as robust mean estimation of multinomial distributions in $\ell_1$ distance: given a batch of $k$ samples from a distribution $p$ over a domain of size $n$, the corresponding frequency vector (whose $i$-th entry is the fraction of samples in the batch equal to $i$) is distributed according to the normalized multinomial distribution given by $k$ draws from $p$. Note that $p$ is precisely the mean of this distribution, so the problem of estimating $p$ from an $\epsilon$-corrupted set of frequency vectors is equivalent to that of robustly estimating the mean of a multinomial distribution.
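As a concrete illustration of this reduction, here is a minimal sketch (assuming batches are arrays of integers over a domain $\{0, \dots, n-1\}$; the function name is ours, not the paper's) of how a batch is converted into the frequency vector whose mean is $p$:

```python
import numpy as np

def batch_to_frequency_vector(batch, n):
    """Convert one batch of samples over {0, ..., n-1} into its normalized
    frequency vector; over the randomness of an honest batch, its mean is p."""
    counts = np.bincount(batch, minlength=n)
    return counts / len(batch)

# Example: a batch of k = 5 samples over a domain of size n = 4.
batch = np.array([0, 2, 2, 3, 0])
print(batch_to_frequency_vector(batch, n=4))  # [0.4 0.  0.4 0.2]
```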

As such, it is natural to try to adapt the existing algorithms for robust mean estimation of other distributions; the fastest of these are based on a simple filtering approach which works as follows. We maintain weights for each point, initialized to uniform. At every step, we measure the maximum “skew” of the weighted dataset in any direction, and if this skew is still too high, update the weights by

  1. Finding the direction in which the corruptions “skew” the dataset the most.

  2. Giving a “score” to each point based on how badly it skews the dataset in this direction.

  3. Downweighting or removing points with high scores.

Otherwise, if the skew is low, output the empirical mean of the weighted dataset.
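The following is a schematic sketch of this generic filtering loop. The callback `skewness_and_scores`, the soft-downweighting rule, and all names here are illustrative placeholders rather than the paper's algorithm; concrete instantiations of the skewness measure and scores are discussed below.

```python
import numpy as np

def generic_filter(X, skewness_and_scores, threshold, max_iters=1000):
    """Schematic filtering loop: reweight the points until the skewness of the
    weighted dataset falls below `threshold`, then return the weighted mean.

    `skewness_and_scores(X, w)` must return (skew, scores), where scores[i] is
    point i's contribution to the skewness; it is a placeholder for the
    instantiations discussed in the text."""
    N = len(X)
    w = np.ones(N) / N                    # weights initialized to uniform
    for _ in range(max_iters):
        skew, scores = skewness_and_scores(X, w)
        if skew <= threshold or scores.max() <= 0:
            break
        # Soft filtering: downweight points in proportion to their scores.
        w = w * (1.0 - scores / scores.max())
        w = w / w.sum()
    return w @ X                          # weighted empirical mean
```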

To prove correctness of this procedure, one must show three things for the particular skewness measure and score function chosen:

  • Regularity: For any sufficiently large collection of -corrupted samples, a particular deterministic regularity condition holds (Definition 3.3 and Lemma 3.3)

  • Soundness: Under the regularity condition, if the skew of the weighted dataset is small, then the empirical mean of the weighted dataset is sufficiently close to the true mean (Lemma 3.4).

  • Progress: Under the regularity condition, if the skew of the weighted dataset is large, then one iteration of the above update scheme will remove more weight from the bad samples than from the good samples (Lemma 3.5).

For isotropic Gaussians, skewness is just given by the maximum variance of the weighted dataset in any direction, i.e. the maximum of $v^\top \Sigma v$ over unit vectors $v$, where $\Sigma$ is the empirical covariance of the weighted dataset. Given a direction $v$ maximizing this variance, the “score” of a point is then simply its contribution to the skewness.

To learn in $\ell_1$ distance, the right set of test vectors to use is the Hamming cube $\{0,1\}^n$, so a natural attempt at adapting the above skewness measure to robust mean estimation of multinomials is to consider the maximum of $v^\top \Sigma v$ over $v \in \{0,1\}^n$. But one of the key challenges in passing from isotropic Gaussians to multinomial distributions is that this quantity is not very informative, because we do not have a good handle on the covariance of the multinomial distribution. In particular, it could be that for a direction $v$, the variance $v^\top \Sigma v$ is high simply because the good points have high variance in that direction to begin with.

The Jain-Orlitsky Correction Term

The clever workaround of [jain2019robust] was to observe that we know exactly what the projection of a multinomial distribution in any direction $v \in \{0,1\}^n$ is: a (normalized) binomial distribution. And so to discern whether the corrupted points skew our estimate in a given direction $v$, one should measure not the variance in the direction $v$, but rather the following corrected quantity: the variance in the direction $v$, minus what the variance would be if the distribution of the projections in the direction $v$ were actually given by the binomial distribution with mean $\langle \widehat{\mu}, v \rangle$, where $\widehat{\mu}$ is the empirical mean of the weighted dataset. This new skewness measure can be written as

(3)

Finding the direction which maximizes this corrected quantity is a Boolean quadratic programming problem, which can be solved approximately by solving the natural SDP relaxation and rounding to a Boolean vector using the machinery of [alon2004approximating]. Using this approach, [jain2019robust] obtained a polynomial-time algorithm for learning general discrete distributions from untrusted batches.
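For concreteness, here is a hedged sketch of what this corrected quantity looks like for a single fixed direction $v \in \{0,1\}^n$, given frequency vectors from batches of size $k$; this is the style of correction described above, not necessarily the exact expression in [jain2019robust]:

```python
import numpy as np

def corrected_skew(X, w, v, k):
    """Weighted empirical variance of the projections <X_i, v>, minus the
    variance that the normalized binomial Bin(k, <mu_hat, v>)/k would have.
    A sketch of the style of correction described above; normalizations may
    differ from the paper's exact formula."""
    proj = X @ v                      # projections of the frequency vectors
    mu = w @ X                        # weighted empirical mean (estimate of p)
    m = mu @ v                        # projected mean
    emp_var = w @ (proj - m) ** 2     # weighted empirical variance
    predicted_var = m * (1 - m) / k   # variance of Bin(k, m)/k
    return emp_var - predicted_var
```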

Learning Structured Distributions

[chen2019efficiently] introduced the question of learning from untrusted batches when the distribution $p$ is known to be structured. Learning structured distributions in the classical sense is well understood: if a distribution is close in total variation distance to being $t$-piecewise degree-$d$, then to estimate it in total variation distance it is enough to approximate it in a much weaker norm, the $\mathcal{A}_K$ norm, where $K$ is a parameter that depends on $t$ and $d$. We review the details in Section 2.4.

In [chen2019efficiently] we gave a sum-of-squares algorithm for robust mean estimation in the $\mathcal{A}_K$ norm that achieved near-optimal error in quasipolynomial time, and a natural open question was to achieve this with a polynomial-time algorithm.

The key challenge that [chen2019efficiently] had to address was that, unlike for the Hamming cube or the sphere, it is unclear how to optimize over the set of test vectors dual to the $\mathcal{A}_K$ norm. Combinatorially, this set is easy to characterize: the $\mathcal{A}_K$ norm of a vector is small if and only if its inner product with $v$ is small for every $v \in \{\pm 1\}^n$ with at most $K$ sign changes when read as a vector from left to right.

The main observation in [chen2019efficiently] is that vectors with few sign changes admit sparse representations in the Haar wavelet basis, so instead of optimizing over this set directly, one can work with a convex relaxation of the Haar-sparsity constraint. As such, if we relax the set of test matrices to all matrices whose Haar transforms are “analytically sparse” in an appropriate convex sense (see Section 2.6 for a formal definition), then, as this relaxed set contains the test matrices for the $\mathcal{A}_K$ norm, it is enough to learn in the norm associated to the relaxation, which is strictly stronger than the $\mathcal{A}_K$ norm.

Our goal then is to produce an estimate whose error is small in this relaxed norm. And even though it is a stronger norm, it turns out that the metric entropy of the relaxed set of test matrices is still small enough that one can get good sample complexity guarantees. Indeed, showing that this is the case (see Lemma A.1) was where the bulk of the technical machinery of [chen2019efficiently] went, and as we elaborate on in Appendix B, the analysis there left some room for tightening. In this work, we give a refined analysis which allows us to get nearly tight sample complexity bounds.

Putting Everything Together

Almost all of the pieces are in place to instantiate the filtering framework: in lieu of the quantity in (3), which can be phrased as the maximization of a quadratic form over the Hamming cube, where the quadratic form depends on the dataset and the weights on its points, we can define our skewness measure as the maximum of this quadratic form over the relaxed set of test matrices, and we can define the score for each point in the dataset to be its contribution to the skewness measure (see Section 3.2).

At this point the reader may be wondering why we never round to an actual direction before computing skewness and scores. As our subsequent analysis will show, it turns out that rounding is unnecessary, both in our setting and even in the unstructured distribution setting considered in [jain2019robust]. Indeed, if one examines the three proof ingredients of regularity, soundness, and progress enumerated above, it becomes evident that the filtering framework for robust mean estimation does not actually require finding a concrete direction in which to filter, merely a skewness measure and score functions which are amenable to proving the above three statements. That said, as we will see, it becomes more technically challenging to prove these ingredients when the test matrix is not rounded to an actual direction (see the discussion in Appendix A), though it is nevertheless possible. We hope that this observation will prove useful in future applications of filtering.

1.2 Related work

The problem of learning from untrusted batches was introduced by [qiao2017learning], and is motivated by problems in reliable distributed learning and federated learning [44822, konevcny2016federated]. The general question of learning from batches has been considered in a number of settings [levi2013testing, tian2017learning] in the theoretical computer science community, but these algorithms do not work in the presence of adversarial noise.

The study of univariate shape constrained density estimation has a long history in statistics and computer science, and we cannot hope to do justice to it here. See [barlow1972statistical] for a survey of classical results in the area, and [o2016nonparametric, diakonikolas2016learning] for a survey of more recent results in this area. Of particular relevance to us are the techniques based on the classical piecewise polynomial (or spline) methods, see e.g. [WegW83, Stone94, Stone97, WillettN07]. Recent work, which we build off of, demonstrates that this framework is capable of achieving nearly-optimal sample complexity and runtime, for a large class of structured distributions [chan2013learning, chan2014efficient, CDSS14b, ADHLS15, acharya2017sample].

Our techniques are also related to a recent line of work on robust statistics [diakonikolas2019robust, lai2016agnostic, charikar2017learning, diakonikolas2017being, hopkins2018mixture, kothari2018robust], a classical problem dating back to the 60s and 70s [anscombe1960rejection, tukey1960survey, huber1992robust, tukey1975mathematics]. See [li2018principled, steinhardt2018robust, diakonikolas2019recent] for a more comprehensive survey of this line of work.

Finally, the most relevant papers to our result are [chen2019efficiently, jain2019robust], which improve upon the result of [qiao2017learning] in terms of runtime and sample complexity. As mentioned above, our result can be thought of as a way to combine the improved filtering algorithm of [jain2019robust] and the shape-constrained technology introduced in [chen2019efficiently].

2 Technical Preliminaries

2.1 Notation

  • Given , let denote the normalized binomial distribution, which takes values in rather than .

  • Let be the simplex of nonnegative vectors whose coordinates sum to 1. Any naturally corresponds to a probability distribution over .

  • Let denote the all-ones vector. We omit the subscript when the context is clear.

  • Given matrix , let denote the maximum absolute value of any entry in , let denote the absolute sum of its entries, and let denote its Frobenius norm.

  • Given , let denote the distribution over given by sampling a frequency vector from the multinomial distribution arising from draws from the distribution over specified by , and dividing by .

  • Given samples and , define to be the set of weights which assigns to all points in and 0 to all other points. Also define its normalization . Let denote the set of weights which are convex combinations of such weights for . Given , define , and define , that is, the empirical mean of the samples indexed by .

  • Given samples , weights , and , define the matrices

    (4)

    When , denote these matrices by and and note that

    (5)

    Also define and . We will also denote by and by .

    To get intuition for these definitions, note that any bitstring corresponding to induces a normalized binomial distribution , and any sample induces a corresponding sample from . Then is the difference between the empirical variance of and the variance of the binomial distribution .
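Since the displayed definitions above did not survive extraction, the following hedged sketch shows one natural way to form the “corrected” second-moment matrix that this intuition describes, assuming the weights sum to one: the weighted empirical covariance of the frequency vectors, minus the covariance that a normalized multinomial with mean equal to the empirical mean would have. The exact normalization used in the paper may differ.

```python
import numpy as np

def corrected_covariance(X, w, k):
    """Weighted empirical covariance of frequency vectors X_i, minus the
    covariance predicted for (1/k) * Multinomial(k, mu), where mu is the
    weighted empirical mean. A sketch of the intuition in the text; the
    paper's exact normalization may differ."""
    mu = w @ X
    centered = X - mu
    emp_cov = (w[:, None] * centered).T @ centered
    predicted_cov = (np.diag(mu) - np.outer(mu, mu)) / k
    return emp_cov - predicted_cov
```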

2.2 The Generative Model

Throughout the rest of the paper, let , , and let be some probability distribution over .

{definition}

We say is an -corrupted -diverse set of batches of size from if they are generated via the following process:

  • For every , is a set of i.i.d. draws from , where is some probability distribution over for which .

  • A computationally unbounded adversary inspects and adds arbitrarily chosen tuples , and returns the entire collection of tuples in any arbitrary order as .

Let denote the indices of the uncorrupted (good) and corrupted (bad) batches.

It turns out that we might as well treat each batch as an unordered tuple. That is, for any batch, define its frequency vector to be the vector whose $j$-th entry is the fraction of samples in the batch equal to $j$. Then the frequency vector of each good batch is an independent draw from the corresponding normalized multinomial distribution. Henceforth, we will work solely in this frequency-vector perspective.
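A minimal simulation of this generative model in the frequency-vector view may be helpful; the adversary below (which always reports a point mass) and all names are illustrative, and the honest batches are drawn from $p$ itself, corresponding to $\eta = 0$.

```python
import numpy as np

rng = np.random.default_rng(0)

def corrupted_batches(p, N, k, eps):
    """Draw N batches of size k: a (1 - eps) fraction are honest batches drawn
    i.i.d. from p (the eta = 0 case), and an eps fraction are adversarial
    (here: a crude adversary putting all mass on symbol 0). Returns the
    frequency vectors in a random order."""
    n = len(p)
    n_bad = int(eps * N)
    good = [rng.multinomial(k, p) / k for _ in range(N - n_bad)]
    bad = [np.eye(n)[0] for _ in range(n_bad)]
    X = np.array(good + bad)
    return rng.permutation(X)

p = np.array([0.1, 0.2, 0.3, 0.4])
X = corrupted_batches(p, N=500, k=50, eps=0.1)
print(X.shape)  # (500, 4)
```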

2.3 Elementary Facts

In this section we collect miscellaneous elementary facts that will be useful in subsequent sections.

{fact}

For , weights , , , and symmetric,

(6)

In particular, by taking for any ,

(7)

That is, the function is minimized over by .

Proof.

Without loss of generality we may assume . Using the fact that for symmetric , we see that

Because , we see that

(8)

from which (6) follows. The remaining parts of the claim follow trivially. ∎

{fact}

For any , let weights satisfy . If is the set of weights defined by for and otherwise, and if , then we have that .

Proof.

We may write

(9)
(10)

where the first step follows by definition of and by triangle inequality, the second step follows by the fact that , and the third step follows by the fact that , while as the samples lie in . ∎

It will be useful to have a basic bound on the Frobenius norm of .

{lemma}

For any and any weights for which , we have that .

Proof.

For any sample , we have that

(11)

and

(12)

from which the lemma follows by triangle inequality and the assumption that . ∎

2.4 Norms and VC Complexity

In this section we review basics about learning distributions which are close to piecewise polynomial.

{definition}

[$\mathcal{A}_K$ norms, see e.g. [devroye2001combinatorial]] For a positive integer $K$, define $\mathcal{A}_K$ to be the set of all unions of at most $K$ disjoint intervals over the domain, where an interval is any set of consecutive domain elements. The $\mathcal{A}_K$ distance between two distributions over the domain is

(13)

Equivalently, say that has sign changes if there are exactly indices for which . Then if denotes the set of all such , we have

(14)

Note that

(15)
{definition}

We say that a distribution over is -piecewise degree- if there is a partition of into disjoint intervals , together with univariate degree- polynomials and a distribution on , such that and such that for all , for all in .

A proof of the following lemma, a consequence of [acharya2017sample], can be found in [chen2019efficiently].

{lemma}

[Lemma 5.1 in [chen2019efficiently], follows by [acharya2017sample]] Let . If is -piecewise degree- and , then there is an algorithm which, given the vector , outputs a distribution for which in time .

Henceforth, we will focus solely on the problem of learning $p$ in the $\mathcal{A}_K$ norm, where

(16)
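For concreteness, here is a brute-force computation of the $\mathcal{A}_K$ distance on a tiny domain, directly following the interval-union definition above; it is exponential in the domain size and intended purely as an executable illustration.

```python
from itertools import combinations
import numpy as np

def num_intervals(S, n):
    """Number of maximal runs of consecutive elements in the subset S of {0,...,n-1}."""
    return sum(1 for i in range(n) if i in S and (i - 1) not in S)

def ak_distance(p, q, K):
    """Brute-force A_K distance: max over unions S of at most K disjoint
    intervals of |p(S) - q(S)|. Exponential in n; for illustration only."""
    n = len(p)
    best = 0.0
    for r in range(n + 1):
        for subset in combinations(range(n), r):
            S = set(subset)
            if num_intervals(S, n) <= K:
                best = max(best, abs(sum(p[i] - q[i] for i in S)))
    return best

p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])
print(ak_distance(p, q, K=1), ak_distance(p, q, K=2))
```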

2.5 Haar Wavelets

We briefly recall the definition of Haar wavelets, further details and examples of which can be found in [chen2019efficiently].

{definition}

Let be a positive integer and let . The Haar wavelet basis is an orthonormal basis over consisting of the father wavelet , the mother wavelet (where contains 1’s and -1’s), and for every for which and , the wavelet whose -th coordinates are and whose -th coordinates are , and whose remaining coordinates are 0.

Additionally, we will use the following notation when referring to Haar wavelets:

  • Let denote the matrix whose rows consist of the vectors of the Haar wavelet basis for . When the context is clear, we will omit the subscript and refer to this matrix as .

  • For , if the -th element of the Haar wavelet basis for is some , then define the weight .

  • For any index , let denote the set of indices for which the -th Haar wavelet is of the form for some .

  • Given any , define the Haar-weighted norm on by , where for every , . Likewise, given any norm on , define the Haar-weighted -norm on by , where for every , .

The key observation is that any vector with at most $K$ sign changes, where $K$ is given by (16), has a sparse representation in the Haar wavelet basis, with at most $O(K \log n)$ nonzero coefficients. We will use the following fundamental fact about Haar wavelets, part of which appears as Lemma 6.3 in [chen2019efficiently].

{lemma}

Let have at most sign changes. Then has at most nonzero entries, and furthermore . In particular, .

Proof.

We first show that has at most nonzero entries. For any with nonzero entries at indices and such that , if has no sign change in the interval , then . For every index at which has a sign change, there are at most choices of for which has a nonzero entry at index , from which the claim follows by a union bound over all choices of , together with the fact that may be nonzero.

Now for each for which , note that

(17)

as claimed. The bounds on follow immediately. ∎
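The following sketch constructs the standard Haar wavelet basis for a domain whose size is a power of two and numerically checks both orthonormality and the sparsity phenomenon above; the indexing conventions are the standard ones and may differ cosmetically from the paper's.

```python
import numpy as np

def haar_basis(n):
    """Rows of the returned matrix form the orthonormal Haar wavelet basis of
    R^n (n a power of two): the father wavelet, the mother wavelet, and the
    localized wavelets at every scale and shift."""
    H = [np.ones(n) / np.sqrt(n)]                     # father wavelet
    m = int(np.log2(n))
    for j in range(m):                                # scale (j = 0 is the mother wavelet)
        for s in range(2 ** j):                       # shift
            w = np.zeros(n)
            block = n // 2 ** j
            start = s * block
            w[start:start + block // 2] = 1.0
            w[start + block // 2:start + block] = -1.0
            H.append(w / np.linalg.norm(w))
    return np.array(H)

H = haar_basis(8)
print(np.allclose(H @ H.T, np.eye(8)))                # orthonormality check

# A {+1, -1} vector with a single sign change has a sparse Haar transform.
v = np.array([1, 1, 1, -1, -1, -1, -1, -1], dtype=float)
print(np.sum(np.abs(H @ v) > 1e-9))                   # few nonzero coefficients
```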

2.6 Convex Relaxation For Finding the Direction of Largest Variance

Recall that in [jain2019robust], the authors consider the binary optimization problem . We would like to approximate the optimization problem . Motivated by [chen2019efficiently] and Lemma 2.5, we consider the following convex relaxation:

{definition}

Let be given by (16). Let denote the (convex) set of all matrices for which

  1. .

  2. .

  3. .

  4. .

  5. .

Let denote the associated norm given by . By abuse of notation, for vectors we will also use to denote .

Because has an efficient separation oracle, one can compute in polynomial time.
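Since the constraints of Definition 2.6 were lost in extraction, the following is only an illustrative sketch of how a relaxation of this general flavor could be encoded in an off-the-shelf solver (cvxpy): maximize $\langle \Sigma, M \rangle$ over symmetric PSD matrices with bounded entries and an $\ell_1$ bound on the Haar transform of $M$ as a convex proxy for Haar sparsity. The actual constraint set in Definition 2.6 differs in its details (in particular the Haar-weighted norms and Constraints 3 and 4).

```python
import cvxpy as cp
import numpy as np

def relaxed_skewness(Sigma, H, sparsity_budget):
    """Illustrative sketch only: maximize <Sigma, M> over a convex set of PSD
    matrices with bounded entries and an l1 bound on the Haar transform
    H M H^T (a convex proxy for Haar sparsity). Not the exact constraint set
    of Definition 2.6."""
    n = Sigma.shape[0]
    M = cp.Variable((n, n), symmetric=True)
    constraints = [
        M >> 0,                                          # positive semidefinite
        cp.abs(M) <= 1,                                  # entrywise bound
        cp.sum(cp.abs(H @ M @ H.T)) <= sparsity_budget,  # Haar "analytic sparsity"
    ]
    prob = cp.Problem(cp.Maximize(cp.trace(Sigma @ M)), constraints)
    prob.solve()
    return prob.value, M.value
```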

{remark}

Note that, besides not being a sum-of-squares program like the one considered in [chen2019efficiently], this relaxation is also slightly different because of Constraints 3 and 4. As we will see in Section B, these additional constraints will be crucial for getting refined sample complexity bounds. Note that Lemma 2.5 immediately implies that is a relaxation of :

{corollary}

[Corollary of Lemma 2.5] for any .

Note also that Constraint 1 in Definition 2.6 ensures that is weaker than and more generally that:

{fact}

For any and , . In particular, for any , .

As a consequence, we conclude the following useful fact about stability of the matrix.

{corollary}

For any , .

Proof.

Take any . By symmetry, it is enough to show that . By Constraint 1, we have that . On the other hand, note that

(18)

where the second step follows from Fact 2.6. The corollary now follows. ∎

Note that if the solution to the convex program were actually integral, that is, some rank-1 matrix for , it would correspond to the direction in which the samples in have the largest discrepancy between the empirical variance and the variance predicted by the empirical mean. Then would correspond to a subset of the domain on which one could filter out bad points as in [jain2019robust]. In the sequel, we will show that this kind of analysis applies even if the solution to is not integral.

3 Filtering Algorithm and Analysis

In this section we prove our main theorem, stated formally below:

{theorem}

Let be an -piecewise degree- distribution over . Then for any smaller than some absolute constant, and any , there is a -time algorithm LearnWithFilter which, given

(19)

-corrupted, -diverse batches of size from , outputs an estimate such that with probability at least over the samples.

In Section 3.1, we first describe and prove guarantees for a basic but important subroutine, 1DFilter, of our algorithm. In Section 3.2, we describe our learning algorithm, LearnWithFilter, in full. In Section 3.3 we define the deterministic conditions that the dataset must satisfy for LearnWithFilter to succeed, deferring the proof that these deterministic conditions hold with high probability (Lemma 3.3) to Appendix A. In Section 3.4 we prove a key geometric lemma (Lemma 3.4). Finally, in Section 3.5, we complete the proof of correctness of LearnWithFilter.

3.1 Univariate Filter

In this section, we define and analyze a simple deterministic subroutine 1DFilter which takes as input a set of weights and a set of scores on the batches and outputs a new set of weights that, whenever the weighted average of the scores among the bad batches exceeds that among the good batches, place relatively less weight on the bad batches than the input weights do. This subroutine is given in Algorithm 1 below.

Input: Scores , weights
Output: New weights with even less mass on bad points than good points (see Lemma 3.1)
for all Output
Algorithm 1 1DFilter()
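The update rule in the Algorithm 1 box did not survive extraction. The following sketch implements the standard soft-filtering update used in this line of work, downweighting each batch in proportion to its score relative to the largest score on the support; we believe this matches 1DFilter up to cosmetic details, and it is consistent with Lemma 3.1 (in particular, any batch attaining the maximum score receives weight zero, so the support strictly shrinks).

```python
import numpy as np

def one_d_filter(scores, w):
    """Soft-filtering update: downweight each batch in proportion to its score
    relative to the largest score among batches that still carry weight.
    Believed to match 1DFilter up to cosmetic details; see the caveat above."""
    scores = np.asarray(scores, dtype=float)
    w = np.asarray(w, dtype=float)
    tau_max = scores[w > 0].max()
    if tau_max <= 0:
        return w.copy()                   # nothing to filter
    return w * np.clip(1.0 - scores / tau_max, 0.0, None)
```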
{lemma}

Let be a set of scores, and let be a weight. Given a partition for which

(20)

then the output of 1DFilter satisfies for all , the support of is a strict subset of the support of , and .

Proof.

and are immediate. For , note that

(21)

from which the lemma follows. ∎

We note that this kind of downweighting scheme and its analysis are not new, see e.g. Lemma 4.5 from [charikar2017learning] or Lemma 17 from [steinhardt2018resilience].

3.2 Algorithm Specification

We can now describe our algorithm LearnWithFilter. At a high level, we maintain weights for each of the batches. In every iteration, we compute maximizing . If , then output . Otherwise, update the weights as follows: for every batch , compute the score given by

(22)

and set the weights to be the output of 1DFilter(). The pseudocode for LearnWithFilter is given in Algorithm 2 below.

Input: Frequency vectors coming from an -corrupted, -diverse set of batches from , where is -piecewise, degree
Output: such that , provided uncorrupted samples -good
1 while  do
2       Compute scores according to (22). 1DFilter()
Using the algorithm of [acharya2017sample] (see Lemma 2.4), output the -piecewise, degree- distribution minimizing (up to additive error ).
Algorithm 2 LearnWithFilter()
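Finally, a schematic of the outer loop of LearnWithFilter, tying the sketches above together. The callback `skewness_and_scores` stands in for the computation of the maximizing test matrix and the scores in (22), `one_d_filter` is the sketch given after Algorithm 1, and the final piecewise-polynomial projection step of [acharya2017sample] is omitted; none of this should be read as the paper's exact pseudocode.

```python
import numpy as np

def learn_with_filter(X, k, threshold, skewness_and_scores, max_iters=None):
    """Schematic outer loop: compute the relaxed skewness and per-batch scores
    of the weighted dataset, soft-filter until the skewness drops below
    `threshold`, and return the weighted mean of the frequency vectors.
    `skewness_and_scores(X, w, k)` is a placeholder callback."""
    N = len(X)
    w = np.ones(N) / N
    for _ in range(max_iters if max_iters is not None else N):
        skew, scores = skewness_and_scores(X, w, k)
        if skew <= threshold:
            break
        w = one_d_filter(scores, w)       # the 1DFilter sketch from Section 3.1
    return (w @ X) / w.sum()              # estimate of p before projection
```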

3.3 Deterministic Condition

{definition}

[-goodness] Take a set of points , and let be a collection of distributions over . For any , define . Denote .

We say is -good if it satisfies that for all for which ,

  1. (Concentration of mean)

    (23)
  2. (Concentration of covariance)

    (24)
  3. (Concentration of variance proxy)

    (25)
  4. (Heterogeneity has negligible effect, see Lemma 3.3)

    (26)
    (27)

We first remark that we only need extremely mild concentration in Condition 3, but it turns out this suffices in the one place where we use it (see Lemma 3.4).

Additionally, note that we can completely ignore Condition 4 when $\eta = 0$. The following makes clear why it is useful when $\eta > 0$.

{lemma}

For -good , all of size , and all ,

(28)
(29)
Proof.

For or and any ,

(30)