Improving Dataset Distillation


Ilia Sucholutsky, Matthias Schonlau
University of Waterloo

Dataset distillation is a method for reducing dataset sizes: the goal is to learn a small number of synthetic samples containing all the information of a large dataset. This has several benefits: speeding up model training in deep learning, reducing energy consumption, and reducing required storage space. Currently, each synthetic sample is assigned a single ‘hard’ label, which limits the accuracy that models trained on distilled datasets can achieve. In addition, dataset distillation can currently only be used with image data.

We propose to simultaneously distill both images and their labels, and thus to assign each synthetic sample a ‘soft’ label (a distribution of labels) rather than a single ‘hard’ label. Our improved algorithm increases accuracy by 2-4% over the original dataset distillation algorithm for several image classification tasks. For example, training a LeNet model with just 10 distilled images (one per class) results in over 96% accuracy on MNIST. Using ‘soft’ labels also enables distilled datasets to consist of fewer samples than there are classes, as each sample can encode information for more than one class. For example, we show that LeNet achieves almost 92% accuracy on MNIST after being trained on just 5 distilled images.

We also propose an extension of the dataset distillation algorithm that allows it to distill sequential datasets including texts. We demonstrate that text distillation outperforms other methods across multiple datasets. For example, we are able to train models to almost their original accuracy on the IMDB sentiment analysis task using just 20 distilled sentences.

Figure 1: 10 distilled MNIST images train networks with fixed initializations to a distillation accuracy of 96.13%. Each image is labelled with its top 3 classes.

1 Introduction

The increase in computational requirements for modern deep learning presents a range of issues. It was recently found that training deep learning models has extremely high energy consumption (Strubell et al., 2019), on top of the already problematic financial cost and time requirements. One way to mitigate these issues is to reduce network size. Hinton et al. (2015) proposed knowledge distillation as a method for imbuing smaller, more efficient networks with all the knowledge of their larger counterparts. Instead of decreasing network size, a second path to efficiency is to decrease dataset size. Dataset distillation (DD) has recently been proposed as an alternative formulation of knowledge distillation that aims to do exactly that (Wang et al., 2018).

Dataset distillation is the process of creating a small number of synthetic samples that can quickly train a network to the same accuracy it would achieve if trained on the original dataset. It may seem counter-intuitive that training a model on a small number of synthetic images coming from a completely different distribution than the training data can achieve the original accuracy, but Wang et al. (2018) have shown that for models with known initializations this is indeed feasible; they achieve 94% accuracy on MNIST, a hand-written digit recognition task (LeCun et al., 1998), after training LeNet on just 10 synthetic images.

We propose to improve their already impressive results by learning ‘soft’ labels as a part of the distillation process. The original dataset distillation algorithm uses fixed, or ‘hard’, labels for the synthetic samples (e.g. the ten synthetic MNIST images each have a label corresponding to a different digit). In other words, each label is a one-hot vector: a vector where all entries are set to zero aside from a single entry, the one corresponding to the correct class, which is set to one. We relax this one-hot restriction and make the synthetic labels learnable. The resulting distilled labels are thus similar to those used for knowledge distillation as a single image can now correspond to multiple classes. Our soft-label dataset distillation (SLDD) not only achieves over 96% accuracy on MNIST when using 10 distilled images, a 2% increase over the state-of-the-art (SOTA), but also achieves almost 92% accuracy with just 5 distilled images, which is less than 1 image per class. In addition to soft labels, we also extend dataset distillation to the natural language/sequence modelling domain. Text Dataset Distillation (TDD) can train a convolutional neural network (CNN) (LeCun et al., 1999) with known initialization up to 91% of its original accuracy on the IMDB sentiment classification task (Maas et al., 2011) using just 20 synthetic sentences. Finally, we revisit the linear regression example discussed by Wang et al. (2018) and derive a new lower-bound for the number of samples required to train a model to original accuracy depending on whether distilled labels are ‘hard’ and static, or ‘soft’ and learnable.

The rest of this work is divided into four sections. In Section 2, we discuss related work in the fields of knowledge distillation, dataset reduction, and example generation. In Section 3, we propose improvements and extensions to dataset distillation and associated theory. In Section 4, we empirically validate SLDD in a wide range of experiments. Finally, in Section 5, we discuss the significance of SLDD and our outlook for the future.

2 Related Work

2.1 Knowledge Distillation

Dataset distillation was originally inspired by network distillation (Hinton et al., 2015) which is a form of knowledge distillation or model compression (Buciluǎ et al., 2006) that has been studied in various contexts including when working with sequential data (Kim and Rush, 2016). Network distillation aims to distill the knowledge of large, or even multiple, networks into a smaller network. Similarly, dataset distillation aims to distill the knowledge of large, or even multiple, datasets into a small number of synthetic samples. ‘Soft’ labels were recently proposed as an effective way of distilling networks by feeding the output probabilities of a larger network directly to a smaller network (Hinton et al., 2015), and have previously been studied in the context of different machine learning algorithms (El Gayar et al., 2006). Our soft-label dataset distillation (SLDD) algorithm also uses ‘soft’ labels but these are persistent and learned over the training phase of a network (rather than being produced during the inference phase as in the case of network distillation).

2.2 Dataset Reduction

There are a large number of methods that aim to reduce the size of a dataset with varying objectives. Active learning aims to reduce the required size of the labelled portion of a dataset by only labelling examples that are determined to be most important (Cohn et al., 1996; Tong and Koller, 2001). Several methods aim to ‘prune’ a dataset, or create a ‘core-set’, by leaving in only examples that are determined to be useful (Angelova et al., 2005; Bachem et al., 2017; Sener and Savarese, 2017; Tsang et al., 2005). In general, all of these methods use samples from the true distribution, typically subsets of the original training set. By lifting this restriction and instead learning synthetic samples, dataset distillation requires far fewer samples to distill the same amount of knowledge.

2.3 Generative Adversarial Networks

Generative Adversarial Networks (GANs) have recently become a very widely used method for image generation and are primarily used to produce images that closely mimic those coming from the true distribution (Ledig et al., 2017; Goodfellow et al., 2014; Choi et al., 2018; Radford et al., 2015). With dataset distillation we instead set knowledge distillation as the objective but do not attempt to produce samples from the true distribution. Using the generator from a trained GAN may be a much faster way of producing images than the gradient-based method employed by dataset distillation. However, since the number of distilled images we aim to produce is very small, solving the objective directly through gradient-based optimization is sufficiently fast, while also more straightforward. Additionally, while some GANs can work with text (Reed et al., 2016; Yu et al., 2017), they are primarily intended for image generation.

2.4 Measuring Problem Dimensionality

We may intuitively believe that one deep learning task is more difficult than another. For example, when comparing the digit recognition task MNIST to the image classification task CIFAR10 (Krizhevsky et al., 2009), it seems that CIFAR10 is the tougher problem, but it is hard to determine to what extent it is tougher. It is possible to try to quantify what exactly it means for one problem to be more difficult than another. One approach is to compare state-of-the-art (SOTA) results on datasets. For example, the near-SOTA ‘dropconnect’ model achieves a 0.21% error rate on MNIST and a 9.32% error rate on CIFAR10 (Wan et al., 2013). However, this approach reveals increasingly little as deeper networks approach perfect accuracy on multiple tasks. Li et al. (2018) instead derive a more model-independent metric for comparing the dimensionality of various problems based on the minimum number of learnable parameters needed to achieve a good local optimum. Similarly, dataset distillation aims to find the minimum number of synthetic samples needed to achieve a good local optimum. The difference is that Li et al. (2018) constrain the number of searchable dimensions within the network weight space, while dataset distillation constrains them within the data space.

3 Improving Dataset Distillation

3.1 Basic Approach

Our underlying approach is the one proposed by Wang et al. (2018). We summarize it here in a slightly modified way that explicitly shows the labels of the distilled dataset. This additional notation becomes useful once we enable label learning in the next section.

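Given a training dataset $\{(x_i, y_i)\}_{i=1}^{N}$, a neural network with parameters $\theta$, and a twice-differentiable loss function $\ell(x_i, y_i, \theta)$, our objective is to find

$$\theta^* = \arg\min_{\theta} \ell(\mathbf{x}, \mathbf{y}, \theta) \triangleq \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} \ell(x_i, y_i, \theta) \tag{1}$$

In general, training with stochastic gradient descent (SGD) involves repeatedly sampling minibatches of training data $(\mathbf{x}_t, \mathbf{y}_t)$ and updating the network parameters $\theta_t$ by their error gradient scaled by the learning rate $\eta$:

$$\theta_{t+1} = \theta_t - \eta \nabla_{\theta_t} \ell(\mathbf{x}_t, \mathbf{y}_t, \theta_t) \tag{2}$$

With dataset distillation, the goal is to perform just one such step while still achieving the same accuracy. We do this by learning a very small number $M$ of synthetic samples $\tilde{\mathbf{x}} = \{\tilde{x}_i\}_{i=1}^{M}$, with labels $\tilde{\mathbf{y}} = \{\tilde{y}_i\}_{i=1}^{M}$, together with a learning rate $\tilde{\eta}$, that minimize $\mathcal{L}$, a one-step loss objective, for

$$\theta_1 = \theta_0 - \tilde{\eta} \nabla_{\theta_0} \ell(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \theta_0) \tag{3}$$

$$\tilde{\mathbf{x}}^*, \tilde{\eta}^* = \arg\min_{\tilde{\mathbf{x}}, \tilde{\eta}} \mathcal{L}(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \tilde{\eta}; \theta_0) = \arg\min_{\tilde{\mathbf{x}}, \tilde{\eta}} \ell\big(\mathbf{x}, \mathbf{y}, \theta_0 - \tilde{\eta} \nabla_{\theta_0} \ell(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \theta_0)\big) \tag{4}$$

Note that, currently, we are minimizing over $\tilde{\mathbf{x}}$ and $\tilde{\eta}$, but not $\tilde{\mathbf{y}}$, as the distilled labels are fixed in the original dataset distillation algorithm. We minimize this objective, or in other words ‘learn the distilled samples’, using standard gradient descent.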
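To make equations (3) and (4) concrete, the following minimal sketch evaluates the one-step objective $\mathcal{L}$ for a linear least-squares model with a fixed initialization $\theta_0 = 0$. The model, the toy data and the hand-picked distilled sample are assumptions chosen so the result can be checked by hand; dataset distillation would instead find such a sample by gradient descent on this objective.

```python
import random

# Illustrative helper names; not taken from the authors' code.


def mse(xs, ys, theta):
    """ell(x, y, theta) = 1/(2n) * sum_i (x_i . theta - y_i)^2 for a linear model."""
    return sum((sum(a * t for a, t in zip(x, theta)) - y) ** 2
               for x, y in zip(xs, ys)) / (2 * len(xs))


def mse_grad(xs, ys, theta):
    """Gradient of mse with respect to theta: 1/n * sum_i (x_i . theta - y_i) x_i."""
    g = [0.0] * len(theta)
    for x, y in zip(xs, ys):
        r = sum(a * t for a, t in zip(x, theta)) - y
        for j, a in enumerate(x):
            g[j] += r * a / len(xs)
    return g


def one_step_objective(x_syn, y_syn, lr, theta0, x_real, y_real):
    """L(x~, y~, eta~; theta0): loss on real data after one GD step on distilled data."""
    g = mse_grad(x_syn, y_syn, theta0)
    theta1 = [t - lr * gj for t, gj in zip(theta0, g)]  # equation (3)
    return mse(x_real, y_real, theta1)                  # objective of equation (4)


# Hypothetical toy data generated by theta* = [2, -1], and a fixed theta0 = 0.
x_real = [[1.0, 2.0], [2.0, 0.5], [-1.0, 1.0], [0.5, -1.5]]
y_real = [2.0 * a - 1.0 * b for a, b in x_real]
theta0 = [0.0, 0.0]

# A single distilled sample with a fixed 'hard' label of 1.
good = one_step_objective([[2.0, -1.0]], [1.0], 1.0, theta0, x_real, y_real)
rng = random.Random(0)
random_x = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]]
bad = one_step_objective(random_x, [1.0], 1.0, theta0, x_real, y_real)
```

Here `good` is exactly zero because one step from $\theta_0 = 0$ lands on $\theta^*$, while a random distilled input leaves a positive loss.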

3.2 Learnable Labels

As mentioned above, one formulation of knowledge distillation proposes that a smaller network be trained on the outputs of a larger network rather than the original training labels. Unlike the training labels, the output labels are not ‘hard’ labels. Because they are outputs of a softmax layer, the output labels form a probability distribution over the possible classes. The idea is that any training image actually contains information about more than one class (e.g. an image of the digit ‘3’ looks a lot like other digits ‘3’ but it also looks like the digit ‘8’). Using ‘soft’ labels allows us to convey more information about the associated image.
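The difference between the two kinds of label can be sketched in a few lines. Below, a ‘hard’ label is a one-hot vector, while a ‘soft’ label is the softmax of a teacher network's logits; the logits are invented for illustration, and the cross-entropy shown is the standard one, not necessarily the exact loss used in our experiments.

```python
import math


def softmax(logits):
    """Turn a vector of logits into a probability distribution (a 'soft' label)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(label, num_classes):
    """'Hard' label: all zeros except a one at the correct class."""
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def cross_entropy(target, probs):
    """-sum_c target_c * log(p_c); works for one-hot and soft targets alike."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t != 0.0)


# Hypothetical teacher logits for an image of a '3' that also resembles an '8'.
teacher_logits = [0.1, 0.0, 0.2, 4.0, 0.0, 0.5, 0.0, 0.3, 2.5, 0.1]
soft_label = softmax(teacher_logits)
hard_label = one_hot(3, 10)
```

The soft label still puts most of its mass on ‘3’, but it also records that ‘8’ is the next most likely class, which the one-hot label cannot express.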

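The original dataset distillation algorithm was restricted to ‘hard’ labels for the distilled data; each distilled image has to be associated with just a single class. We relax this restriction and allow distilled labels to take on any real value. Since the distilled labels are now continuous variables, we can make them learnable using the same method as for the distilled images: a combination of backpropagation and gradient descent. With our modified notation, we simply need to change equation (4) to also minimize over $\tilde{\mathbf{y}}$:

$$\tilde{\mathbf{x}}^*, \tilde{\mathbf{y}}^*, \tilde{\eta}^* = \arg\min_{\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \tilde{\eta}} \mathcal{L}(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \tilde{\eta}; \theta_0) = \arg\min_{\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \tilde{\eta}} \ell\big(\mathbf{x}, \mathbf{y}, \theta_0 - \tilde{\eta} \nabla_{\theta_0} \ell(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \theta_0)\big) \tag{5}$$
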
Algorithm 1 details this soft-label dataset distillation (SLDD) algorithm. In our experiments, we generally initialize $\tilde{\mathbf{y}}$ with the one-hot values that ‘hard’ labels would have. We found that this tends to increase accuracy compared to random initialization, perhaps because it encourages more differentiation between classes early in the distillation process.

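Input: $p(\theta_0)$: distribution of initial weights; $M$: the number of distilled data; $\alpha$: step size; $n$: batch size; $T$: number of optimization iterations; $\tilde{\mathbf{y}}_0$: initial value for $\tilde{\mathbf{y}}$; $\tilde{\eta}_0$: initial value for $\tilde{\eta}$

1:  Initialize distilled data $\tilde{\mathbf{x}} = \{\tilde{x}_i\}_{i=1}^{M}$ randomly, $\tilde{\mathbf{y}} \leftarrow \tilde{\mathbf{y}}_0$, $\tilde{\eta} \leftarrow \tilde{\eta}_0$
2:  for each training step t = 1 to T do
3:     Get a minibatch of real training data $(\mathbf{x}_t, \mathbf{y}_t) = \{(x_{t,j}, y_{t,j})\}_{j=1}^{n}$
4:     One-hot encode the labels $\mathbf{y}_t$
5:     Sample a batch of initial weights $\theta_0^{(j)} \sim p(\theta_0)$
6:     for each sampled $\theta_0^{(j)}$ do
7:        Compute updated model parameter with GD: $\theta_1^{(j)} = \theta_0^{(j)} - \tilde{\eta} \nabla_{\theta_0^{(j)}} \ell(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \theta_0^{(j)})$
8:        Evaluate the objective function on real training data: $\mathcal{L}^{(j)} = \ell(\mathbf{x}_t, \mathbf{y}_t, \theta_1^{(j)})$
9:     end for
10:     Update distilled data $\tilde{\mathbf{x}} \leftarrow \tilde{\mathbf{x}} - \alpha \nabla_{\tilde{\mathbf{x}}} \sum_j \mathcal{L}^{(j)}$, $\tilde{\mathbf{y}} \leftarrow \tilde{\mathbf{y}} - \alpha \nabla_{\tilde{\mathbf{y}}} \sum_j \mathcal{L}^{(j)}$, and $\tilde{\eta} \leftarrow \tilde{\eta} - \alpha \nabla_{\tilde{\eta}} \sum_j \mathcal{L}^{(j)}$
11:  end for

Output: distilled data $\tilde{\mathbf{x}}$; distilled labels $\tilde{\mathbf{y}}$; optimized learning rate $\tilde{\eta}$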

Algorithm 1 Soft-Label Dataset Distillation (SLDD)
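The loop in Algorithm 1 can be sketched on a toy problem. This is a minimal sketch under several assumptions: a linear model trained with squared error on one-hot targets stands in for LeNet and cross-entropy; forward finite differences stand in for backpropagation through the inner GD step; a Gaussian with standard deviation 0.1 is a hypothetical $p(\theta_0)$; and the function names and hyperparameters are illustrative. With a fixed initialization, line 5 would simply return the same $\theta_0$ every time.

```python
import random

# Illustrative toy version of Algorithm 1; names and settings are hypothetical.


def predict(W, x):
    """Linear model with one output per class; W has shape [C][D]."""
    return [sum(w * a for w, a in zip(row, x)) for row in W]


def mse(xs, ys, W):
    """1/(2n) * sum_i ||W x_i - y_i||^2; targets may be one-hot or real-valued."""
    return sum(sum((p - t) ** 2 for p, t in zip(predict(W, x), y))
               for x, y in zip(xs, ys)) / (2 * len(xs))


def gd_step(W, xs, ys, lr):
    """One gradient-descent step on (xs, ys) starting from W (Algorithm 1, line 7)."""
    G = [[0.0] * len(W[0]) for _ in W]
    for x, y in zip(xs, ys):
        r = [p - t for p, t in zip(predict(W, x), y)]
        for c in range(len(W)):
            for j in range(len(x)):
                G[c][j] += r[c] * x[j] / len(xs)
    return [[w - lr * g for w, g in zip(wr, gr)] for wr, gr in zip(W, G)]


def one_hot(label, num_classes):
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def sldd(x_real, labels, num_classes, m, iters=150, batch=8, n_inits=4,
         alpha=0.05, h=1e-6, seed=0):
    """Toy SLDD: learns distilled inputs x~, soft labels y~ and learning rate eta~."""
    rng = random.Random(seed)
    d = len(x_real[0])
    # Flat parameter list: x~ (random), then y~ (one-hot initialised), then eta~.
    params = [rng.gauss(0.0, 1.0) for _ in range(m * d)]
    for i in range(m):
        params += one_hot(i % num_classes, num_classes)
    params.append(0.1)

    def unpack(p):
        xs = [p[i * d:(i + 1) * d] for i in range(m)]
        off = m * d
        ys = [p[off + i * num_classes:off + (i + 1) * num_classes] for i in range(m)]
        return xs, ys, p[-1]

    def objective(p, xb, yb, inits):
        # Lines 6-9: one GD step from each sampled theta0, then the loss on real data.
        xs, ys, lr = unpack(p)
        return sum(mse(xb, yb, gd_step(W0, xs, ys, lr)) for W0 in inits) / len(inits)

    for _ in range(iters):
        idx = rng.sample(range(len(x_real)), batch)                  # line 3
        xb = [x_real[i] for i in idx]
        yb = [one_hot(labels[i], num_classes) for i in idx]          # line 4
        inits = [[[rng.gauss(0.0, 0.1) for _ in range(d)]            # line 5
                  for _ in range(num_classes)] for _ in range(n_inits)]
        base = objective(params, xb, yb, inits)
        grad = []
        for k in range(len(params)):  # finite differences stand in for backprop
            bumped = params[:]
            bumped[k] += h
            grad.append((objective(bumped, xb, yb, inits) - base) / h)
        params = [p - alpha * g for p, g in zip(params, grad)]      # line 10
    return unpack(params)


# Hypothetical two-class data: class 0 near (1, 0), class 1 near (0, 1).
data_rng = random.Random(1)
x_real, labels = [], []
for i in range(20):
    centre = [1.0, 0.0] if i % 2 == 0 else [0.0, 1.0]
    x_real.append([v + data_rng.gauss(0.0, 0.05) for v in centre])
    labels.append(i % 2)

x_syn, y_syn, lr = sldd(x_real, labels, num_classes=2, m=2)
```

After training, `y_syn` is no longer one-hot: the labels have moved away from their initialization just like the inputs, which is the only structural change SLDD makes to the original loop.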

3.3 Text and Other Sequences

The original dataset distillation algorithm was only shown to work with image data, but intuitively, there is no reason why text or other sequences should not be similarly distillable. However, it is difficult to use gradient methods directly on text data as it is discrete. In order to be able to use SLDD with text data we need to first embed the text data into a continuous space. Any popular embedding method can be used, but in our experiments we used pre-trained GloVe embeddings (Pennington et al., 2014). Once the text is embedded into a continuous space, the problem of distilling it becomes analogous to image distillation. In fact, if all sentences are padded/truncated to some pre-determined length (we use 400 in our experiments), then each sentence is essentially just a one-channel image of size [length][embedding dimension]. The resulting algorithm for text dataset distillation (TDD) is detailed in Algorithm 2. It is important to note that the embedding is performed only on sentences coming from the true dataset; the distilled samples are learned directly as embedded representations.
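The pre-processing that turns a real sentence into a one-channel ‘image’ can be sketched as follows. We pad or truncate to 400 tokens and use GloVe in our experiments; this sketch instead assumes a length of 5 and a made-up 3-dimensional embedding table, and mapping padding and unknown words to zero vectors is likewise an assumption.

```python
def pad_or_truncate(tokens, length, pad_token="<pad>"):
    """Cut a token list to `length`, or right-pad it with pad_token."""
    return tokens[:length] + [pad_token] * max(0, length - len(tokens))


def embed(tokens, table, dim):
    """Stack one embedding row per token; unknown and padding tokens map to zeros."""
    zero = [0.0] * dim
    return [list(table.get(tok, zero)) for tok in tokens]


# Hypothetical 3-dimensional table standing in for pre-trained GloVe vectors.
toy_glove = {"a": [0.1, 0.2, 0.3], "great": [0.9, -0.1, 0.4], "movie": [0.0, 0.5, -0.2]}
sentence = "a great movie".split()
matrix = embed(pad_or_truncate(sentence, 5), toy_glove, 3)  # a 5 x 3 'one-channel image'
```

Only real sentences pass through these two functions; distilled samples are created and updated directly as matrices of this shape.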

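Input: $p(\theta_0)$: distribution of initial weights; $M$: the number of distilled data; $\alpha$: step size; $n$: batch size; $T$: number of optimization iterations; $\tilde{\mathbf{y}}_0$: initial value for $\tilde{\mathbf{y}}$; $\tilde{\eta}_0$: initial value for $\tilde{\eta}$; $s$: sentence length; $e$: embedding size

1:  Initialize distilled data $\tilde{\mathbf{x}} = \{\tilde{x}_i\}_{i=1}^{M}$ randomly, each of size $s \times e$; $\tilde{\mathbf{y}} \leftarrow \tilde{\mathbf{y}}_0$, $\tilde{\eta} \leftarrow \tilde{\eta}_0$
2:  for each training step t = 1 to T do
3:     Get a minibatch of real training data $(\mathbf{x}_t, \mathbf{y}_t) = \{(x_{t,j}, y_{t,j})\}_{j=1}^{n}$
4:     Pad (or truncate) each sentence in the minibatch to length $s$
5:     Embed each sentence in the minibatch into an $s \times e$ matrix
6:     One-hot encode the labels $\mathbf{y}_t$
7:     Sample a batch of initial weights $\theta_0^{(j)} \sim p(\theta_0)$
8:     for each sampled $\theta_0^{(j)}$ do
9:        Compute updated model parameter with GD: $\theta_1^{(j)} = \theta_0^{(j)} - \tilde{\eta} \nabla_{\theta_0^{(j)}} \ell(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \theta_0^{(j)})$
10:        Evaluate the objective function on real training data: $\mathcal{L}^{(j)} = \ell(\mathbf{x}_t, \mathbf{y}_t, \theta_1^{(j)})$
11:     end for
12:     Update distilled data $\tilde{\mathbf{x}} \leftarrow \tilde{\mathbf{x}} - \alpha \nabla_{\tilde{\mathbf{x}}} \sum_j \mathcal{L}^{(j)}$, $\tilde{\mathbf{y}} \leftarrow \tilde{\mathbf{y}} - \alpha \nabla_{\tilde{\mathbf{y}}} \sum_j \mathcal{L}^{(j)}$, and $\tilde{\eta} \leftarrow \tilde{\eta} - \alpha \nabla_{\tilde{\eta}} \sum_j \mathcal{L}^{(j)}$
13:  end for

Output: distilled data $\tilde{\mathbf{x}}$; distilled labels $\tilde{\mathbf{y}}$; optimized learning rate $\tilde{\eta}$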

Algorithm 2 Text Dataset Distillation (TDD)

3.4 Random initializations and multiple steps

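The procedures described above make one important assumption: that the network initialization $\theta_0$ is fixed. Samples created this way do not lead to high accuracies when the network is re-trained on them with a different initialization, as they contain information not only about the dataset but also about $\theta_0$. In the distilled images in Figures 1 and 2, this can be seen as what looks like a lot of random noise. Wang et al. (2018) propose that the method instead be generalized to work with network initializations randomly sampled from some restricted distribution $p(\theta_0)$:

$$\tilde{\mathbf{x}}^*, \tilde{\mathbf{y}}^*, \tilde{\eta}^* = \arg\min_{\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \tilde{\eta}} \mathbb{E}_{\theta_0 \sim p(\theta_0)} \mathcal{L}(\tilde{\mathbf{x}}, \tilde{\mathbf{y}}, \tilde{\eta}; \theta_0) \tag{6}$$
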
The resulting images, especially for MNIST, appear to have much clearer patterns and much less random noise, and the results detailed in Section 4 suggest that this method generalizes fairly well to other randomly sampled initializations from the same distribution.
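The expectation in equation (6) is in practice estimated by averaging over sampled initializations. The sketch below assumes a linear least-squares model and a standard Gaussian as a hypothetical $p(\theta_0)$, and estimates that expectation for two hand-built distilled sets: one that is exact only for $\theta_0 = 0$, and the identity construction analysed in Section 3.5, which works for any $\theta_0$.

```python
import random


def mse(xs, ys, theta):
    """ell(x, y, theta) = 1/(2n) * sum_i (x_i . theta - y_i)^2 for a linear model."""
    return sum((sum(a * t for a, t in zip(x, theta)) - y) ** 2
               for x, y in zip(xs, ys)) / (2 * len(xs))


def one_step(theta0, x_syn, y_syn, lr):
    """theta_1 = theta_0 - lr * gradient of the distilled-data loss (equation 3)."""
    g = [0.0] * len(theta0)
    for x, y in zip(x_syn, y_syn):
        r = sum(a * t for a, t in zip(x, theta0)) - y
        for j, a in enumerate(x):
            g[j] += r * a / len(x_syn)
    return [t - lr * gj for t, gj in zip(theta0, g)]


def expected_objective(x_syn, y_syn, lr, x_real, y_real, n_samples=200, seed=0):
    """Monte Carlo estimate of E_{theta0 ~ p(theta0)} L(x~, y~, eta~; theta0)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        # Hypothetical p(theta0): independent standard Gaussians.
        theta0 = [rng.gauss(0.0, 1.0) for _ in range(len(x_real[0]))]
        total += mse(x_real, y_real, one_step(theta0, x_syn, y_syn, lr))
    return total / n_samples


# Hypothetical data generated by theta* = [2, -1].
x_real = [[1.0, 2.0], [2.0, 0.5], [-1.0, 1.0], [0.5, -1.5]]
y_real = [2.0 * a - 1.0 * b for a, b in x_real]

# Tuned for theta0 = 0 only: exact at that initialization, poor elsewhere.
fixed_init = expected_objective([[2.0, -1.0]], [1.0], 1.0, x_real, y_real)
# Identity inputs, lr = M = 2 and labels equal to theta*: exact for every theta0.
any_init = expected_objective([[1.0, 0.0], [0.0, 1.0]], [2.0, -1.0], 2.0, x_real, y_real)
```

The first distilled set encodes information about one particular $\theta_0$, so its expected loss is large; the second encodes only the dataset's optimum and its expected loss is essentially zero.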

Additionally, Wang et al. (2018) suggest that the above methods can work with multiple gradient descent steps. If we want to perform multiple gradient descent steps, each with a different minibatch of distilled data, we simply need to backpropagate the gradient through every one of these additional steps. Finally, it may also be beneficial to train the neural networks on the distilled data for more than one epoch. The experimental results suggest that multiple steps and multiple epochs improve distillation performance for both image and text data, particularly when using random network initializations.
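Multiple steps and epochs amount to unrolling the inner training loop. The sketch below, which assumes a linear least-squares model and a hypothetical two-step distilled dataset, runs one GD step per distilled minibatch and repeats the sequence for several epochs. Because the final parameters depend on every minibatch, a gradient with respect to any of them has to be propagated back through all later steps; this sketch only shows the forward pass.

```python
def mse(xs, ys, theta):
    """ell(x, y, theta) = 1/(2n) * sum_i (x_i . theta - y_i)^2 for a linear model."""
    return sum((sum(a * t for a, t in zip(x, theta)) - y) ** 2
               for x, y in zip(xs, ys)) / (2 * len(xs))


def mse_grad(xs, ys, theta):
    """Gradient of mse with respect to theta."""
    g = [0.0] * len(theta)
    for x, y in zip(xs, ys):
        r = sum(a * t for a, t in zip(x, theta)) - y
        for j, a in enumerate(x):
            g[j] += r * a / len(xs)
    return g


def unrolled_training(theta0, distilled_steps, epochs=1):
    """Apply one GD step per distilled minibatch, cycling through them `epochs` times.

    distilled_steps is a list of (x~_i, y~_i, eta~_i) triples, one per GD step.
    """
    theta = list(theta0)
    for _ in range(epochs):
        for xs, ys, lr in distilled_steps:
            g = mse_grad(xs, ys, theta)
            theta = [t - lr * gj for t, gj in zip(theta, g)]
    return theta


# Hypothetical two-step distilled dataset for a linear model with theta* = [2, -1].
steps = [([[1.0, 0.0]], [2.0], 1.0),   # step 1 only moves the first weight
         ([[0.0, 1.0]], [-1.0], 1.0)]  # step 2 only moves the second weight
x_real = [[1.0, 2.0], [2.0, 0.5], [-1.0, 1.0], [0.5, -1.5]]
y_real = [2.0 * a - 1.0 * b for a, b in x_real]

# The same distilled data with a smaller learning rate per step.
slow_steps = [(xs, ys, 0.5) for xs, ys, _ in steps]
one_epoch = mse(x_real, y_real, unrolled_training([0.0, 0.0], slow_steps, epochs=1))
two_epochs = mse(x_real, y_real, unrolled_training([0.0, 0.0], slow_steps, epochs=2))
```

With the smaller learning rate, a second epoch over the same distilled data halves the remaining parameter error in this toy case, which mirrors the benefit of multiple epochs reported above.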

3.5 Analysis of Simple Linear Case

We revisit the linear regression case study from Wang et al. (2018) but examine the impact of using ‘hard’ versus ‘soft’ labels. The goal of this case study is to derive a lower bound for the number of distilled samples required to achieve original accuracy with arbitrary random model initializations. The model used in this case study is linear regression with a quadratic loss.

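We have a data matrix $\mathbf{d} \in \mathbb{R}^{N \times D}$ and an $N \times 1$ target matrix $\mathbf{t}$.

Given a $D \times 1$ weight matrix $\theta$, we also have the quadratic loss

$$\ell(\mathbf{d}, \mathbf{t}, \theta) = \frac{1}{2N}\lVert \mathbf{d}\theta - \mathbf{t} \rVert^2$$

We solve for an $M \times D$ matrix $\tilde{\mathbf{d}}$, an $M \times 1$ target matrix $\tilde{\mathbf{t}}$, and the learning rate $\tilde{\eta}$ that minimize

$$\ell\big(\mathbf{d}, \mathbf{t}, \theta_0 - \tilde{\eta} \nabla_{\theta_0} \ell(\tilde{\mathbf{d}}, \tilde{\mathbf{t}}, \theta_0)\big)$$

After training our model for a single step of gradient descent, we would have

$$\theta_1 = \theta_0 - \frac{\tilde{\eta}}{M}\tilde{\mathbf{d}}^T(\tilde{\mathbf{d}}\theta_0 - \tilde{\mathbf{t}}) \tag{7}$$

Of course, for linear regression, the global minima are achieved at

$$\theta^* = (\mathbf{d}^T\mathbf{d})^{-1}\mathbf{d}^T\mathbf{t} \tag{8}$$

Plugging equation 7 into equation 8 (that is, requiring $\theta_1 = \theta^*$), we get

$$\Big(I - \frac{\tilde{\eta}}{M}\tilde{\mathbf{d}}^T\tilde{\mathbf{d}}\Big)\theta_0 + \frac{\tilde{\eta}}{M}\tilde{\mathbf{d}}^T\tilde{\mathbf{t}} = (\mathbf{d}^T\mathbf{d})^{-1}\mathbf{d}^T\mathbf{t} \tag{9}$$
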

At this point, Wang et al. (2018) make the assumption that the feature columns of $\mathbf{d}$ are independent. We believe that this is a fairly restrictive assumption, as features in datasets often have some dependence. For example, pixels in an image are very likely to be spatially correlated. We propose to rework this assumption to make it much milder: we instead assume that the user can perform some sort of feature selection or feature engineering such that the resulting dataset has feature columns that are all linearly independent. This modification does not change any of the intermediate steps, but it does mean that the actual lower bound on the number of distilled samples required to achieve original accuracy depends on the number of linearly independent features of a dataset rather than the total number of features. The outcome of this assumption is still that $\mathbf{d}^T\mathbf{d}$ has full rank (and is therefore invertible).

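If we want equation 9 to hold for all arbitrary $\theta_0$, the coefficient of $\theta_0$ must vanish, i.e. $\tilde{\mathbf{d}}^T\tilde{\mathbf{d}} = \frac{M}{\tilde{\eta}} I$, and we must have the following.

$\tilde{\mathbf{d}}^T\tilde{\mathbf{d}}$ has full rank, meaning $M \geq D$ (the lower bound discussed above).

We want the smallest possible distilled dataset, so we set $M = D$.

$\tilde{\mathbf{d}}$ is then square and orthogonal up to a scale factor, $\tilde{\mathbf{d}} = \sqrt{M/\tilde{\eta}}\,Q$ with $Q^T Q = I$, and the remaining terms of equation 9 give

$$\frac{\tilde{\eta}}{M}\tilde{\mathbf{d}}^T\tilde{\mathbf{t}} = (\mathbf{d}^T\mathbf{d})^{-1}\mathbf{d}^T\mathbf{t}$$

Since $\tilde{\mathbf{d}}$ is a scaled orthogonal matrix, this is geometrically equivalent to rotating the vector $\tilde{\mathbf{t}}$ and scaling it by $\sqrt{\tilde{\eta}/M}$. As a result, there will always be a solution to this set of equations. However, because rotations preserve length, we must have that

$$\frac{\tilde{\eta}}{M}\lVert\tilde{\mathbf{t}}\rVert^2 = \lVert(\mathbf{d}^T\mathbf{d})^{-1}\mathbf{d}^T\mathbf{t}\rVert^2 = \lVert\theta^*\rVert^2$$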

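Case 1 - ‘hard’, fixed labels ($\tilde{\mathbf{t}}$ is fixed):

Since $\tilde{\mathbf{t}}$ is fixed, so is $\tilde{\eta}$ by the equation above. In fact, because each entry of $\tilde{\mathbf{t}}$ is restricted to zero or one, we have $\lVert\tilde{\mathbf{t}}\rVert^2 \leq M$, so $\tilde{\eta} \geq \lVert\theta^*\rVert^2$. This means that, depending on the dataset, we may need very high learning rates if we use ‘hard’, fixed labels.

Case 2 - ‘soft’, learnable labels ($\tilde{\mathbf{t}} \in \mathbb{R}^{M \times 1}$):

This is trivially solvable. For example, set $\tilde{\mathbf{d}} = I$, $\tilde{\eta} = M$, and $\tilde{\mathbf{t}} = \theta^*$. In fact, the learning rate can be arbitrarily low or high, as $\tilde{\mathbf{d}}$ and $\tilde{\mathbf{t}}$ can be scaled accordingly: $\tilde{\mathbf{d}} = cI$, $\tilde{\eta} = M/c^2$, $\tilde{\mathbf{t}} = c\,\theta^*$ works for any $c > 0$.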

In Case 1, we examine the effects of ‘hard’ labels by setting all the elements of $\tilde{\mathbf{t}}$ to either one or zero, thereby simulating binary labels; $\tilde{\mathbf{t}}$ is also fixed ahead of time, since it is not learnable in this scenario. In Case 2, we let the elements of $\tilde{\mathbf{t}}$ take any real values, and $\tilde{\mathbf{t}}$ is learnable. In both cases, it is possible to learn a distilled dataset that will train our linear model to its original accuracy. However, by using ‘soft’, learnable labels, we add degrees of freedom that remove restrictions from the other learnable parameters. This allows us to tune those parameters for other desirable qualities. For example, we can use an arbitrarily small learning rate, which is important when we work with batches of data instead of the entire dataset at once as in this linear regression example. More generally, it appears that having additional degrees of freedom makes it easier to find good solutions, in the same way that having more parameters in a model does.
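Both cases can be checked numerically. The sketch below assumes small toy data and plain-Python linear algebra; it builds the Case 2 solution for several scale factors $c$ (and hence several learning rates) and the Case 1 solution for fixed binary targets, and confirms that one step of equation 7 lands on $\theta^*$ from random initializations. Using a Householder reflection as the orthogonal matrix in Case 1 is an assumption of this sketch; any orthogonal matrix mapping the direction of $\tilde{\mathbf{t}}$ onto that of $\theta^*$ would do.

```python
import math
import random

# Illustrative helper names; not taken from the authors' code.


def transpose(A):
    return [list(col) for col in zip(*A)]


def matvec(A, v):
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def solve(A, b):
    """Solve A x = b with Gaussian elimination and partial pivoting."""
    n = len(A)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for k in range(n):
        p = max(range(k, n), key=lambda i: abs(M[i][k]))
        M[k], M[p] = M[p], M[k]
        for i in range(k + 1, n):
            f = M[i][k] / M[k][k]
            for j in range(k, n + 1):
                M[i][j] -= f * M[k][j]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x


def least_squares(d, t):
    """theta* = (d^T d)^{-1} d^T t, equation (8)."""
    cols = transpose(d)
    gram = [[sum(a * b for a, b in zip(ci, cj)) for cj in cols] for ci in cols]
    return solve(gram, matvec(cols, t))


def one_step(theta0, d_syn, t_syn, lr):
    """Equation (7): theta_1 = theta_0 - (lr / M) d~^T (d~ theta_0 - t~)."""
    m = len(d_syn)
    residual = [p - y for p, y in zip(matvec(d_syn, theta0), t_syn)]
    grad = [g / m for g in matvec(transpose(d_syn), residual)]
    return [a - lr * g for a, g in zip(theta0, grad)]


def soft_construction(theta_star, c):
    """Case 2: d~ = cI, eta~ = M / c^2, t~ = c theta* for any c > 0 (here M = D)."""
    D = len(theta_star)
    d_syn = [[c if i == j else 0.0 for j in range(D)] for i in range(D)]
    return d_syn, [c * s for s in theta_star], D / c ** 2


def hard_construction(theta_star, t_syn):
    """Case 1: t~ is fixed, which forces eta~ = M ||theta*||^2 / ||t~||^2.

    d~ = sqrt(M / eta~) H, where H is the Householder reflection mapping the
    direction of t~ onto the direction of theta*.
    """
    D = len(theta_star)
    t_norm = math.sqrt(sum(v * v for v in t_syn))
    s_norm = math.sqrt(sum(v * v for v in theta_star))
    lr = D * s_norm ** 2 / t_norm ** 2
    w = [a / t_norm - b / s_norm for a, b in zip(t_syn, theta_star)]
    ww = sum(v * v for v in w)
    H = [[(1.0 if i == j else 0.0) - (2.0 * w[i] * w[j] / ww if ww > 0 else 0.0)
          for j in range(D)] for i in range(D)]
    scale = math.sqrt(D / lr)
    return [[scale * h for h in row] for row in H], lr


# Hypothetical data generated by theta* = [2, -1].
d = [[1.0, 2.0], [2.0, 0.5], [-1.0, 1.0], [0.5, -1.5]]
t = [2.0 * a - 1.0 * b for a, b in d]
theta_star = least_squares(d, t)
```

For this data, Case 2 reaches $\theta^*$ with learning rates 200, 2 and 0.02 (for $c$ = 0.1, 1 and 10), while Case 1 with the binary targets [1, 0] is forced to a learning rate of about 10, above $\lVert\theta^*\rVert^2 = 5$.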

We also note that with our modified assumption about the linear independence of features, the lower bound on the number of distilled samples needed to achieve original accuracy is $M \geq D$, where $D$ is now the number of linearly independent features rather than the total number of features. This lower bound is more in line with the empirical results in Section 4, which show very small distilled datasets, much smaller than the total number of features, achieving close to original accuracy.
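The feature-selection step is left open above. One simple option, assumed here purely for illustration, is to keep a column only if it is not numerically in the span of the columns kept so far, using Gram-Schmidt orthogonalization; the number of kept columns is then the $D$ in the bound $M \geq D$.

```python
import math


def independent_columns(matrix, tol=1e-9):
    """Indices of a maximal set of linearly independent columns, chosen greedily.

    Each column is orthogonalised against the columns already kept (Gram-Schmidt)
    and is kept only if a non-negligible component remains.
    """
    basis, kept = [], []
    for j in range(len(matrix[0])):
        v = [row[j] for row in matrix]
        for b in basis:
            proj = sum(a * c for a, c in zip(v, b))
            v = [a - proj * c for a, c in zip(v, b)]
        norm = math.sqrt(sum(a * a for a in v))
        if norm > tol:
            basis.append([a / norm for a in v])
            kept.append(j)
    return kept


# Hypothetical 3 x 4 data matrix: column 1 = 2 * column 0, column 3 = column 0 + column 2.
d = [[1.0, 2.0, 0.0, 1.0],
     [0.0, 0.0, 1.0, 1.0],
     [1.0, 2.0, 0.0, 1.0]]
kept = independent_columns(d)
reduced = [[row[j] for j in kept] for row in d]
```

Here four features reduce to two independent ones, so the bound becomes $M \geq 2$ rather than $M \geq 4$.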

4 Experiments

4.1 Metrics

The simplest metric for gauging distillation performance is to train a model on distilled samples and then test it on real samples. We refer to the accuracy achieved on these real samples as the ‘distillation accuracy’. However, several of the models we use in our experiments do not achieve SOTA accuracy on the datasets they are paired with, so it is useful to construct a relative metric that compares distillation accuracy to original accuracy. The first such metric is the ‘distillation ratio’, which we define as the ratio of distillation accuracy to original accuracy. The distillation ratio depends heavily on the number of distilled samples $M$, so we index it by $M$ and may refer to it as the ‘$M$-sample distillation ratio’ when clarification is needed. It may also be of interest to find the minimum number of distilled samples required to achieve a certain distillation ratio. To this end, we define a second relative metric, the ‘$X$% distillation size’: the minimum number of distilled samples required to achieve a distillation ratio of $X$%.
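Both relative metrics are straightforward to compute once accuracies are known. In the sketch below, the IMDB numbers come from Tables 4 and 6, while the size-to-ratio mapping is hypothetical; searching only over the sizes that were actually evaluated is an assumption about how the minimum would be found in practice.

```python
def distillation_ratio(distillation_acc, original_acc):
    """Ratio of distillation accuracy to original accuracy."""
    return distillation_acc / original_acc


def distillation_size(ratio_by_size, target_percent):
    """Smallest evaluated number of distilled samples whose ratio reaches target_percent.

    ratio_by_size maps a distilled-dataset size M to its M-sample distillation ratio.
    Returns None if no evaluated size reaches the target.
    """
    sizes = [m for m, r in ratio_by_size.items() if 100.0 * r >= target_percent]
    return min(sizes) if sizes else None


# IMDB with random initialization and 20 distilled sentences (Tables 4 and 6).
imdb_ratio = distillation_ratio(73.3978, 87.1)
# Hypothetical ratios measured for several distilled-dataset sizes.
ratios = {5: 0.62, 10: 0.81, 20: 0.90, 50: 0.93}
```

The IMDB value reproduces the 84.27% 20-sample distillation ratio reported in Section 4.3; for the hypothetical ratios, the 80% distillation size is 10.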

4.2 Image Data

The LeNet model we use with MNIST achieves nearly SOTA results, 99% accuracy, so it is sufficient to use distillation accuracy when describing distillation performance with it. However, AlexCifarNet only achieves 80% on CIFAR10 so it is helpful to use the 2 relative metrics when describing this set of distillation results.


For image data, we use the baseline results as determined by Wang et al. (2018). These results are shown in Table 2.

Fixed initialization

When the network initialization is kept fixed between the distillation and training phases, dataset distillation produces synthetic images that result in very high distillation accuracies. The SLDD algorithm produces images that result in equal or higher accuracies when compared to the original DD algorithm. For example, DD can produce 10 distilled images that train a LeNet model up to 93.76% accuracy on MNIST. Meanwhile, SLDD can produce 10 distilled images, seen in Figure 1, that train the same model up to 96.13% accuracy. The full distilled labels for these 10 images are laid out in Table 1. SLDD can even learn a tiny set of just 5 distilled images that train the same model to 91.56% accuracy. As can be seen in Figure 3, adding more distilled images typically increases distillation accuracy, but this begins to plateau fairly quickly. Similarly, SLDD provides a 7.5% increase in 100-sample distillation ratio (6% increase in distillation accuracy) on CIFAR10 over DD. Based on these results, detailed further in Table 3, it appears that SLDD is even more effective than DD at distilling image data into a small number of samples. This intuitively makes sense, as the learnable labels used by SLDD increase the capacity of the distilled dataset for storing information.

Distilled 0 1 2 3 4 5 6 7 8 9
1 2.34 -0.33 0.23 0.04 -0.03 -0.23 -0.32 0.54 -0.39 0.49
2 -0.17 2.58 0.32 0.37 -0.68 -0.19 -0.75 0.53 0.27 -0.89
3 -0.26 -0.35 2.00 0.07 0.08 0.42 0.02 -0.08 -1.09 0.10
4 -0.28 0.04 0.59 2.08 -0.61 -1.11 0.52 0.19 -0.20 0.32
5 -0.11 -0.52 -0.08 0.90 2.63 -0.44 -0.72 -0.39 -0.29 0.87
6 0.25 -0.20 -0.19 0.51 -0.02 2.47 0.62 -0.42 -0.52 -0.63
7 0.42 0.55 -0.09 -1.07 0.83 -0.19 2.16 -0.30 0.26 -0.91
8 0.18 -0.33 -0.25 0.06 -0.91 0.55 -1.17 2.11 0.94 0.47
9 0.46 -0.48 0.24 0.09 -0.78 0.75 0.47 -0.40 2.45 -0.71
10 -0.53 0.52 -0.74 -1.32 1.03 0.23 0.05 0.55 0.31 2.45
Table 1: Learned distilled labels for the 10 distilled MNIST images in Figure 1
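The figure captions label each distilled image with its top 3 classes. Assuming ‘top’ means the largest learned label values (our reading of the captions), these can be read off the rows of Table 1 as follows. Every row's largest entry lies on the diagonal, consistent with the one-hot label initialization described in Section 3.2.

```python
# Learned label rows from Table 1 (row i is distilled image i + 1, columns are digits 0-9).
TABLE_1 = [
    [2.34, -0.33, 0.23, 0.04, -0.03, -0.23, -0.32, 0.54, -0.39, 0.49],
    [-0.17, 2.58, 0.32, 0.37, -0.68, -0.19, -0.75, 0.53, 0.27, -0.89],
    [-0.26, -0.35, 2.00, 0.07, 0.08, 0.42, 0.02, -0.08, -1.09, 0.10],
    [-0.28, 0.04, 0.59, 2.08, -0.61, -1.11, 0.52, 0.19, -0.20, 0.32],
    [-0.11, -0.52, -0.08, 0.90, 2.63, -0.44, -0.72, -0.39, -0.29, 0.87],
    [0.25, -0.20, -0.19, 0.51, -0.02, 2.47, 0.62, -0.42, -0.52, -0.63],
    [0.42, 0.55, -0.09, -1.07, 0.83, -0.19, 2.16, -0.30, 0.26, -0.91],
    [0.18, -0.33, -0.25, 0.06, -0.91, 0.55, -1.17, 2.11, 0.94, 0.47],
    [0.46, -0.48, 0.24, 0.09, -0.78, 0.75, 0.47, -0.40, 2.45, -0.71],
    [-0.53, 0.52, -0.74, -1.32, 1.03, 0.23, 0.05, 0.55, 0.31, 2.45],
]


def top_k_classes(label_row, k=3):
    """Classes with the k largest learned label values, largest first."""
    return sorted(range(len(label_row)), key=lambda c: label_row[c], reverse=True)[:k]


top3 = [top_k_classes(row) for row in TABLE_1]
```

For example, the first distilled image is mostly a ‘0’ but also carries positive label mass for ‘7’ and ‘9’.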
Figure 2: 100 distilled CIFAR10 images train networks with fixed initializations. Each image is labelled with its top 3 classes. Only 3 of the 10 steps are shown (panels: Step 0, Step 5, Step 9).

Random initialization

It is also of interest to know whether distilled data stores information only about the chosen network initialization, or whether it can actually store knowledge contained within the training data. To this end, we perform experiments by sampling random network initializations generated using the Xavier Initialization (Glorot and Bengio, 2010). The distilled images produced in this way are more representative of the training data, but generally result in lower accuracies when models are trained on them. Once again, images distilled using SLDD lead to higher distillation accuracies than DD when the number of distilled images is held constant. For example, 100 MNIST images learned by DD result in accuracies of 79.5 ± 8.1%, while 100 images learned by SLDD result in accuracies of 82.75 ± 2.75%. There is similarly a 3.8% increase in 100-sample distillation ratio (3% increase in distillation accuracy) when using SLDD instead of DD on CIFAR10 using 100 distilled images each. It is also interesting to note that the actual distilled images, as seen in Figures 4 and 5, appear to have much clearer patterns emerging than in the fixed initialization case. These results suggest that DD, and even more so SLDD, can be generalized to work with random initializations and distill knowledge about the dataset itself when they are trained this way. All the mean and standard deviation results for random initializations in Table 3 are derived by testing with 200 randomly initialized networks.

Figure 3: Distillation accuracy on MNIST with LeNet for different distilled dataset sizes.
Figure 4: 100 distilled MNIST images train networks with random initializations. Each image is labelled with its top 3 classes. Only 3 of the 10 steps are shown (panels: Step 0, Step 5, Step 9).
Figure 5: 100 distilled CIFAR10 images train networks with random initializations. Each image is labelled with its top 3 classes. Only 3 of the 10 steps are shown (panels: Step 0, Step 5, Step 9).
Used as training data in the same number of GD steps
Dataset   Rand. real   Optim. real   k-means      Avg. real
MNIST     68.6 ± 9.8   73.0 ± 7.6    76.4 ± 9.5   77.1 ± 2.7
CIFAR10   21.3 ± 1.5   23.4 ± 1.3    22.5 ± 3.1   22.3 ± 0.3
Used in K-NN
Dataset   Rand. real   k-means
MNIST     71.5 ± 2.1   92.2 ± 0.1
CIFAR10   18.8 ± 1.3   29.4 ± 0.3
Table 2: Baseline results on image data (accuracy, %)
Dataset   SLDD fixed   SLDD random   DD fixed   DD random
MNIST     98.6         82.7 ± 2.75   96.6       79.5 ± 8.1
CIFAR10   60.0         39.8 ± 0.83   54.0       36.8 ± 1.2
Table 3: Distillation accuracy (%) with image data

4.3 Text Data

As mentioned above, TDD does not work in the space of the original raw data, but rather produces synthetic samples from the embedding space. Because each distilled ‘sentence’ is actually a matrix, the embedding layer should not be a part of the model but instead needs to be used as a pre-processing step that is applied only to real sentences from the training data. For text experiments, we use the IMDB sentiment analysis dataset, the Stanford Sentiment Treebank 5-class task (SST5) (Socher et al., 2013), and the Text Retrieval Conference question classification tasks with 6 (TREC6) and 50 (TREC50) classes (Voorhees and Tice, 1999). The text experiments are performed with a fairly shallow but wide CNN model. This model does not achieve SOTA accuracies on the text datasets we use in distillation experiments, so the model’s original accuracy on each dataset is detailed in Table 4.

IMDB   SST5   TREC6   TREC50
87.1   42.3   89.6    84.4
Table 4: CNN model accuracies (%) on text datasets


We consider the same four baselines but generalize them slightly so that they work for our text experiments. This means that we perform baseline experiments in the embedding space.

  • Random real sentences: We randomly sample the same number of real sentences per class, pad/truncate them, and look up their embeddings.

  • Optimized real sentences: We sample and pre-process different sets of random real sentences as above, but now we choose the 20% of the sets that have the best performance.

  • k-means: First, we pre-process the sentences. Then, we use k-means to learn clusters for each class, and use the resulting centroids to train.

  • Average real sentences: First, we pre-process the sentences. Then, we compute the average embedded matrix for each class and use it for training.
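The ‘Average real sentences’ baseline operates entirely in the embedding space and can be sketched as below. The 2 x 2 matrices stand in for padded and embedded sentences and are hypothetical; the dictionary-based grouping is an implementation choice of this sketch.

```python
def average_per_class(embedded, labels):
    """'Average real sentences' baseline: mean embedded matrix for each class."""
    sums, counts = {}, {}
    for matrix, label in zip(embedded, labels):
        if label not in sums:
            sums[label] = [[0.0] * len(matrix[0]) for _ in matrix]
            counts[label] = 0
        for i, row in enumerate(matrix):
            for j, value in enumerate(row):
                sums[label][i][j] += value
        counts[label] += 1
    return {label: [[v / counts[label] for v in row] for row in total]
            for label, total in sums.items()}


# Hypothetical padded and embedded 2 x 2 'sentences' from two classes.
embedded = [[[1.0, 0.0], [0.0, 1.0]],
            [[3.0, 0.0], [0.0, 3.0]],
            [[0.0, 2.0], [2.0, 0.0]]]
labels = [0, 0, 1]
averages = average_per_class(embedded, labels)
```

Each class is summarized by one matrix of the same shape as an embedded sentence, so it can be used as training data exactly like a distilled sample.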

The baseline results are shown in Table 5.

Fixed initialization

When the network initialization is kept fixed between the distillation and training phases, dataset distillation can also produce synthetic text that results in very high model accuracies. For example, TDD can produce 20 distilled sentences that train the CNN model up to a distillation ratio of 91.62% on the IMDB dataset. Even for far more difficult language tasks, TDD still has impressive results but with larger distilled datasets. For example, for the 50-class TREC50 task, it can learn 1000 sentences that train the model to a distillation ratio of 79.86%. Some examples of TDD results are detailed in Table 6.

Random initialization

Curiously, TDD has much less of a performance drop when using random initialization than SLDD does. TDD can learn 20 sentences that train a network with random initialization up to a distillation ratio of 84.27% on IMDB, only slightly lower than in the fixed initialization case. Similarly, for TREC6 there is only a 2.07% difference between the 60-sample distillation ratio for fixed and random initializations; the accuracies were within a standard deviation of each other. Only in the case of the much tougher TREC50 task was there a large decrease in performance when working with random initializations. All the mean and standard deviation results for random initializations in Table 6 are derived by testing with 200 randomly initialized networks.

Used as training data in the same number of GD steps
Dataset   Rand. real      Optim. real    k-means         Avg. real
IMDB      49.71 ± 0.88    49.85 ± 0.77   49.85 ± 0.58    50.00 ± 0.14
SST5      21.16 ± 4.92    24.64 ± 2.59   19.57 ± 4.52    21.29 ± 4.11
TREC6     37.53 ± 10.08   44.60 ± 7.52   34.44 ± 12.98   27.98 ± 9.53
TREC50    8.17 ± 5.96     9.86 ± 6.56    14.74 ± 5.48    12.49 ± 6.39
Used in K-NN
Dataset   Rand. real      k-means
IMDB      49.98 ± 0.13    50.00 ± 0.00
SST5      23.08 ± 0.01    20.85 ± 2.1
TREC6     31.52 ± 9.87    50.5 ± 6.8
TREC50    15.4 ± 5.08     45.06 ± 6.59
Table 5: Baseline results on text data (accuracy, %). Each baseline uses 10 GD steps, except for IMDB with k-means and the entire TREC50 row, which use 2 GD steps due to GPU memory constraints and, for TREC50, insufficient training samples in some classes.
Dataset   Fixed     Random
IMDB      75.032    73.3978 ± 3.28
SST5      37.4661   36.3023 ± 1.45
TREC6     79.2      77.335 ± 2.93
TREC50    57.6      11 ± 0.00
TREC50    67.4      42.102 ± 2.09
Table 6: Distillation accuracy (%) with text data. Each result is based on 10 GD steps with 1 distilled sentence per class, except for the second TREC50 row, which is based on 5 GD steps with 4 sentences per class. Standard deviations are included for experiments with random initializations.

5 Conclusion

We have improved dataset distillation performance by enabling learnable distilled labels, and we have greatly increased the space of datasets with which it can be used. We have also demonstrated these improvements empirically on a number of different datasets.

However, even with SLDD and TDD, there are still some limitations to dataset distillation. Network initializations are still limited to coming from the same distribution, and no testing has yet been done on whether a single distilled dataset can be used to train networks with different architectures. Further investigations are needed to determine how dataset distillation can be generalized to work with more variation in initializations, and even across networks with different architectures. Also, we have so far only shown dataset distillation working with CNNs. Fortunately, there is nothing in the dataset distillation algorithm that limits it to this network type. As long as a network has a twice-differentiable loss function and the gradient can be back-propagated all the way to the inputs, then that network is compatible with dataset distillation. A natural extension of TDD is to enable it to be used with recurrent neural networks (RNN), such as Long Short-Term Memory (LSTM) networks, as these are commonly used for tasks in the natural language domain (Hochreiter and Schmidhuber, 1997). We believe the best way to do this will be to simply use backpropagation through time (BPTT) (Werbos and others, 1990) to backpropagate the error gradients.

One direction we think is worth exploring in more depth is whether the minimum distillation size of a dataset corresponds well to task difficulty, and whether it could therefore serve as a tool for comparing machine learning tasks. Another potentially promising direction is to use distilled datasets to speed up Neural Architecture Search and other very compute-intensive meta-algorithms. If distilled datasets are a good proxy for performance evaluation, they could potentially reduce search times by multiple orders of magnitude. More broadly, dataset distillation is an exciting new direction within knowledge distillation; further improvements may not only help us better understand our datasets but also enable a number of applications in efficient machine learning.

We would like to thank Dr. Sebastian Fischmeister for providing us with the computational resources that enabled us to perform many of the experiments found in this work.

Conflict of interest

The authors declare that they have no conflict of interest.


  • A. Angelova, Y. Abu-Mostafa, and P. Perona (2005) Pruning training sets for learning of object categories. In 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’05), Vol. 1, pp. 494–501.
  • O. Bachem, M. Lucic, and A. Krause (2017) Practical coreset constructions for machine learning. arXiv preprint arXiv:1703.06476.
  • C. Buciluǎ, R. Caruana, and A. Niculescu-Mizil (2006) Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 535–541.
  • Y. Choi, M. Choi, M. Kim, J. Ha, S. Kim, and J. Choo (2018) StarGAN: unified generative adversarial networks for multi-domain image-to-image translation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8789–8797.
  • D. A. Cohn, Z. Ghahramani, and M. I. Jordan (1996) Active learning with statistical models. Journal of Artificial Intelligence Research 4, pp. 129–145.
  • N. El Gayar, F. Schwenker, and G. Palm (2006) A study of the robustness of kNN classifiers trained using soft labels. In IAPR Workshop on Artificial Neural Networks in Pattern Recognition, pp. 67–80.
  • X. Glorot and Y. Bengio (2010) Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 249–256.
  • I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio (2014) Generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2672–2680.
  • G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
  • S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural Computation 9 (8), pp. 1735–1780.
  • Y. Kim and A. M. Rush (2016) Sequence-level knowledge distillation. arXiv preprint arXiv:1606.07947.
  • A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images. Technical report, Citeseer.
  • Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, et al. (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324.
  • Y. LeCun, P. Haffner, L. Bottou, and Y. Bengio (1999) Object recognition with gradient-based learning. Shape, Contour and Grouping in Computer Vision, pp. 823–823.
  • C. Ledig, L. Theis, F. Huszár, J. Caballero, A. Cunningham, A. Acosta, A. Aitken, A. Tejani, J. Totz, Z. Wang, et al. (2017) Photo-realistic single image super-resolution using a generative adversarial network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4681–4690.
  • C. Li, H. Farkhoor, R. Liu, and J. Yosinski (2018) Measuring the intrinsic dimension of objective landscapes. arXiv preprint arXiv:1804.08838.
  • A. L. Maas, R. E. Daly, P. T. Pham, D. Huang, A. Y. Ng, and C. Potts (2011) Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, Portland, Oregon, USA, pp. 142–150.
  • J. Pennington, R. Socher, and C. Manning (2014) GloVe: global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp. 1532–1543.
  • A. Radford, L. Metz, and S. Chintala (2015) Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
  • S. Reed, Z. Akata, X. Yan, L. Logeswaran, B. Schiele, and H. Lee (2016) Generative adversarial text to image synthesis. arXiv preprint arXiv:1605.05396.
  • O. Sener and S. Savarese (2017) Active learning for convolutional neural networks: a core-set approach. arXiv preprint arXiv:1708.00489.
  • R. Socher, A. Perelygin, J. Wu, J. Chuang, C. D. Manning, A. Ng, and C. Potts (2013) Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, pp. 1631–1642.
  • E. Strubell, A. Ganesh, and A. McCallum (2019) Energy and policy considerations for deep learning in NLP. arXiv preprint arXiv:1906.02243.
  • S. Tong and D. Koller (2001) Support vector machine active learning with applications to text classification. Journal of Machine Learning Research 2 (Nov), pp. 45–66.
  • I. W. Tsang, J. T. Kwok, and P. Cheung (2005) Core vector machines: fast SVM training on very large data sets. Journal of Machine Learning Research 6 (Apr), pp. 363–392.
  • E. M. Voorhees and D. M. Tice (1999) The TREC-8 question answering track evaluation. In TREC, Vol. 1999, pp. 82.
  • L. Wan, M. Zeiler, S. Zhang, Y. LeCun, and R. Fergus (2013) Regularization of neural networks using DropConnect. In International Conference on Machine Learning, pp. 1058–1066.
  • T. Wang, J. Zhu, A. Torralba, and A. A. Efros (2018) Dataset distillation. arXiv preprint arXiv:1811.10959.
  • P. J. Werbos (1990) Backpropagation through time: what it does and how to do it. Proceedings of the IEEE 78 (10), pp. 1550–1560.
  • L. Yu, W. Zhang, J. Wang, and Y. Yu (2017) SeqGAN: sequence generative adversarial nets with policy gradient. In AAAI-17: Thirty-First AAAI Conference on Artificial Intelligence, Vol. 31, pp. 2852–2858.