Set Transformer
Abstract
Many machine learning tasks such as multiple instance learning, 3D shape recognition, and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the permutation of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from the sparse Gaussian process literature. It reduces the computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating increased performance compared to recent methods for set-structured data.
1 Introduction
Learning representations has proven to be an essential problem for deep learning and its many success stories. The majority of problems tackled by deep learning are instance-based and take the form of mapping a fixed-dimensional input tensor to its corresponding target value (Krizhevsky et al., 2012; Graves et al., 2013). For some applications, we are required to process set-structured data. Multiple instance learning (Dietterich et al., 1997; Maron & Lozano-Pérez, 1998) is an example of such a set-input problem, where a set of instances is given as an input and the corresponding target is a label for the entire set. Other problems such as 3D shape recognition (Wu et al., 2015; Shi et al., 2015; Su et al., 2015; Charles et al., 2017), sequence ordering (Vinyals et al., 2016), and various set operations (Muandet et al., 2012; Edwards & Storkey, 2017; Zaheer et al., 2017) can also be viewed as such set-input problems. Moreover, many meta-learning (Thrun & Pratt, 1998; Schmidhuber, 1987) problems which learn using a set of different but related tasks may also be treated as set-input tasks where an input set corresponds to the training dataset of a single task. For example, few-shot image classification (Finn et al., 2017; Snell et al., 2017; Lee & Choi, 2018) operates by building a classifier using a support set of images, which is evaluated with query images.
A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size. While these requirements stem from the definition of a set, they are not easily satisfied in neural-network-based models: classical feedforward neural networks violate both requirements, and RNNs are sensitive to input order.
Recently, Edwards & Storkey (2017) and Zaheer et al. (2017) proposed neural network architectures which meet both criteria, which we call set pooling methods. In this model, each element in a set is first independently fed into a feedforward neural network that takes fixed-size inputs. The resulting feature-space embeddings are then aggregated using a pooling operation (sum, mean, max, or similar). The final output is obtained by further nonlinear processing of the aggregated embedding. This remarkably simple architecture satisfies both aforementioned requirements, and more importantly, is proven to be a universal approximator for any set function (Zaheer et al., 2017). Thanks to this property, it is possible to learn a complex mapping between input sets and their target outputs in a black-box fashion, much like with feedforward or recurrent neural networks.
Even though this set pooling approach is theoretically attractive, it remains unclear whether we can approximate complex mappings well using only instance-based feature extractors and simple pooling operations. Since every element in a set is processed independently in a set pooling operation, some information regarding interactions between elements has to be necessarily discarded. This can make some classes of problems unnecessarily difficult to solve.
Consider the problem of meta-clustering: we would like to learn a parametric mapping from an input set of points to the centers of the clusters in the set, for many such sets. Even though a neural network with a set pooling operation can approximate such a mapping by learning to quantize space, this quantization cannot depend on the contents of the set. This limits the quality of the solution and may make optimization of such a model more difficult; we show empirically in Section 4 that it leads to underfitting.
In this paper, we propose a novel set-input deep neural network architecture called the Set Transformer (cf. Transformer, Vaswani et al. (2017)). The novelty of the Set Transformer comes from three important design choices: 1) We use a self-attention mechanism based on the Transformer to process every element in an input set, which allows our approach to naturally encode pairwise or higher-order interactions between elements in the set. 2) We propose a method to reduce the $O(n^2)$ computation time of full self-attention to $O(nm)$, where $m$ is a fixed hyperparameter. 3) We use a self-attention mechanism to aggregate features, which is especially beneficial when the problem of interest requires multiple dependent outputs, such as the problem of meta-clustering, where the meaning of each cluster center heavily depends on its location relative to the other clusters. We apply the Set Transformer to several set-input problems and empirically demonstrate the importance and effectiveness of these design choices.
This paper is organized as follows. In Section 2, we briefly review the concept of set functions, existing architectures, and the selfattention mechanism. In Section 3, we introduce Set Transformers, our novel neural network architecture for set functions. In Section 4, we present various experiments that demonstrate the benefits of the Set Transformer. We discuss related works in Section 5 and conclude the paper in Section 6.
2 Background
2.1 Pooling Architecture for Sets
Problems involving a set of objects have the permutation invariance property: the target value for a given set is the same regardless of the order of objects in the set. A simple example of a permutation invariant model is a network that performs pooling over embeddings extracted from the elements of a set. More formally,
$$\mathrm{net}(\{x_1, \dots, x_n\}) = \rho\big(\mathrm{pool}(\{\phi(x_1), \dots, \phi(x_n)\})\big) \quad (1)$$
Zaheer et al. (2017) have proven that all permutation invariant functions can be represented as (1) when $\mathrm{pool}$ is the sum operator and $\rho, \phi$ are any continuous functions, thus justifying the use of this architecture for set-input problems.
Note that we can deconstruct (1) into two parts: an encoder ($\phi$) which independently acts on each element of a set of $n$ items, and a decoder ($\rho(\mathrm{pool}(\cdot))$) which aggregates these encoded features and produces our desired output. Most network architectures for set-structured data follow this encoder-decoder structure. Our proposed method is also composed of an encoder and a decoder, but our embedding function does not act independently on each item; it instead considers the whole set to obtain the embedding. Additionally, instead of a fixed function such as sum or mean, our aggregating function is parameterized and can thus adapt to the problem at hand.
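The set pooling architecture of (1) can be sketched in a few lines. The following is a minimal NumPy illustration with placeholder random weights (a one-layer encoder and decoder; the actual experiments use learned deep networks), showing that sum-pooling makes the whole mapping permutation invariant:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions for illustration only.
d_in, d_hid, d_out = 2, 16, 3
W1 = rng.normal(size=(d_in, d_hid))   # phi: per-element encoder (one layer here)
W2 = rng.normal(size=(d_hid, d_out))  # rho: decoder acting on the pooled embedding

def pooling_net(X):
    """Eq. (1): rho(pool({phi(x_1), ..., phi(x_n)})) with sum-pooling."""
    H = np.tanh(X @ W1)        # phi applied independently to each element
    pooled = H.sum(axis=0)     # permutation-invariant sum pooling
    return np.tanh(pooled @ W2)

X = rng.normal(size=(5, d_in))                  # a set of 5 elements
perm = rng.permutation(5)
assert np.allclose(pooling_net(X), pooling_net(X[perm]))  # permutation invariant
```

Because each element passes through `phi` in isolation, any interaction between elements must be recovered by `rho` from the single pooled vector, which motivates the attention-based alternative below.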
2.2 Attention
Assume we have $n$ query vectors (corresponding to $n$ points in an input set), each with dimension $d_q$: $Q \in \mathbb{R}^{n \times d_q}$. An attention function $\mathrm{Att}(Q, K, V)$ maps the queries $Q$ to outputs using $n_v$ key-value pairs $K \in \mathbb{R}^{n_v \times d_q}$, $V \in \mathbb{R}^{n_v \times d_v}$:
$$\mathrm{Att}(Q, K, V; \omega) = \omega(Q K^\top) V \quad (2)$$
The pairwise dot products $Q K^\top \in \mathbb{R}^{n \times n_v}$ measure how similar each pair of query and key vectors is, with weights computed with an activation function $\omega$. The output $\omega(Q K^\top) V$ is a weighted sum of $V$, where a value gets more weight if its corresponding key has a larger dot product with the query.
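Dot-product attention as in (2) can be sketched directly (a minimal NumPy version with $\omega$ chosen as row-wise softmax; dimensions are arbitrary):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def att(Q, K, V):
    """Dot-product attention: omega(Q K^T) V with omega = row-wise softmax."""
    weights = softmax(Q @ K.T)   # (n, n_v): how much each query attends to each key
    return weights @ V           # (n, d_v): weighted sum of the values

rng = np.random.default_rng(1)
n, n_v, d = 4, 6, 8
Q = rng.normal(size=(n, d))
K = rng.normal(size=(n_v, d))
V = rng.normal(size=(n_v, d))
out = att(Q, K, V)
assert out.shape == (n, d)
```

Each row of `weights` is a distribution over the $n_v$ values, so every output is a convex combination of rows of $V$.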
Multihead attention, originally introduced in Vaswani et al. (2017), is an extension of the previous attention scheme. Instead of computing a single attention function, this method first projects $Q, K, V$ onto $h$ different $(d_q^M, d_q^M, d_v^M)$-dimensional vectors, respectively. An attention function $\mathrm{Att}(\cdot; \omega_j)$ is applied to each of these $h$ projections. The output is a linear transformation of the concatenation of all attention outputs:
$$\mathrm{Multihead}(Q, K, V; \lambda, \omega) = \mathrm{concat}(O_1, \dots, O_h) W^O, \quad (3)$$
$$\text{where } O_j = \mathrm{Att}(Q W_j^Q, K W_j^K, V W_j^V; \omega_j). \quad (4)$$
Note that $\mathrm{Multihead}$ has learnable parameters $\lambda = \{W_j^Q, W_j^K, W_j^V\}_{j=1}^h$ and $W^O$, where $W_j^Q, W_j^K \in \mathbb{R}^{d_q \times d_q^M}$, $W_j^V \in \mathbb{R}^{d_v \times d_v^M}$, and $W^O \in \mathbb{R}^{h d_v^M \times d}$. A typical choice for the dimension hyperparameters is $d_q^M = d_q / h$, $d_v^M = d_v / h$, $d = d_q$. For brevity, we set $d_q = d_v = d$ and $d_q^M = d_v^M = d / h$ throughout the rest of the paper. Unless specified otherwise, we use the scaled softmax $\omega_j(\cdot) = \mathrm{softmax}(\cdot / \sqrt{d})$, which our experiments showed worked robustly in most settings.
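The projection-and-concatenation scheme of (3)-(4) can be sketched as follows (a minimal NumPy version with random weights; the per-head scaling uses the head dimension, a common variant of the scaled softmax above):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multihead(Q, K, V, params):
    """Eqs. (3)-(4): project per head, attend with scaled softmax, concat, mix with W_O."""
    heads = []
    for W_q, W_k, W_v in params["heads"]:
        Qj, Kj, Vj = Q @ W_q, K @ W_k, V @ W_v
        d_head = Qj.shape[1]
        heads.append(softmax(Qj @ Kj.T / np.sqrt(d_head)) @ Vj)  # O_j
    return np.concatenate(heads, axis=1) @ params["W_o"]

rng = np.random.default_rng(2)
d, h = 8, 2
d_head = d // h  # the typical choice d_M = d / h
params = {
    "heads": [tuple(rng.normal(size=(d, d_head)) for _ in range(3)) for _ in range(h)],
    "W_o": rng.normal(size=(h * d_head, d)),
}
Q, K, V = (rng.normal(size=(5, d)) for _ in range(3))
assert multihead(Q, K, V, params).shape == (5, d)
```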
3 Set transformer
Table 1: Comparison of set operations.

| Set operations | Time complexity | High-order interactions | Permutation invariant |
|---|---|---|---|
| Recurrent | $O(n)$ | Yes | No |
| Pooling (Zaheer et al., 2017) | $O(n)$ | No | Yes |
| Relational Networks (Santoro et al., 2017) | $O(n^2)$ | Yes | Yes |
| Set Transformer (SAB + PMA, ours) | $O(n^2)$ | Yes | Yes |
| Set Transformer (ISAB + PMA, ours) | $O(nm)$ | Yes | Yes |
In this section, we motivate and describe the Set Transformer: an attention-based neural network architecture that is designed to process sets of data. A Set Transformer consists of an encoder followed by a decoder (cf. Section 2.1). The encoder transforms a set of instances into a set of features, which the decoder transforms into the desired fixed-dimensional output.
3.1 Attentionbased set operations
We begin by defining our attention-based set operations. While existing pooling methods for sets obtain instance features independently of other instances, we use self-attention to concurrently encode the whole set. This gives the Set Transformer the ability to preserve pairwise as well as higher-order interactions among instances during the encoding process. For this purpose, we adapt the multihead attention mechanism used in the Transformer. We emphasize that all blocks introduced here are neural network blocks with their own parameters, and not fixed functions.
Given matrices $X, Y \in \mathbb{R}^{n \times d}$ which represent two sets of $d$-dimensional vectors, we define the Multihead Attention Block (MAB) with parameters $\omega$ as follows:
$$\mathrm{MAB}(X, Y) = \mathrm{LayerNorm}(H + \mathrm{rFF}(H)), \quad (5)$$
$$\text{where } H = \mathrm{LayerNorm}(X + \mathrm{Multihead}(X, Y, Y; \omega)), \quad (6)$$
where $\mathrm{rFF}$ is any row-wise feedforward layer (i.e., it processes each instance independently and identically), and $\mathrm{LayerNorm}$ is layer normalization (Ba et al., 2016). The MAB is an adaptation of the encoder block of the Transformer (Vaswani et al., 2017) without positional encoding and dropout. Using the MAB, we define the Set Attention Block (SAB) as
$$\mathrm{SAB}(X) := \mathrm{MAB}(X, X). \quad (7)$$
In other words, an SAB takes a set and performs self-attention between the elements in the set, resulting in a set of equal size. Since the output of the SAB contains information about pairwise interactions among the elements in the input set $X$, we can stack multiple SABs to encode higher-order interactions. Note that while the SAB (7) involves a self-attention operation (6) with $Q = K = V = X$, it could in principle reduce to applying a residual block on $X$. In practice, it learns more complicated functions due to the linear projections of $X$ inside the attention heads, (2) and (4).
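A minimal NumPy sketch of the MAB and SAB (single-head attention and a one-layer rFF for brevity, with random weights) makes the equivariance property concrete:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(X, eps=1e-5):
    mu, var = X.mean(axis=-1, keepdims=True), X.var(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

def mab(X, Y, W_ff):
    """Eqs. (5)-(6), simplified to one head and a one-layer ReLU rFF."""
    H = layer_norm(X + softmax(X @ Y.T / np.sqrt(X.shape[1])) @ Y)  # attend X to Y
    return layer_norm(H + np.maximum(H @ W_ff, 0.0))                # row-wise rFF

def sab(X, W_ff):
    """Eq. (7): self-attention between the elements of a single set."""
    return mab(X, X, W_ff)

rng = np.random.default_rng(3)
n, d = 6, 4
X, W_ff = rng.normal(size=(n, d)), rng.normal(size=(d, d))
perm = rng.permutation(n)
# SAB is permutation *equivariant*: permuting the input permutes the output the same way.
assert np.allclose(sab(X, W_ff)[perm], sab(X[perm], W_ff))
```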
A potential problem with using SABs for set-structured data is the quadratic time complexity $O(n^2)$, which may be too expensive for large sets. We thus introduce the Induced Set Attention Block (ISAB), which bypasses this problem. Along with the set $X \in \mathbb{R}^{n \times d}$, we additionally define $m$ $d$-dimensional vectors $I \in \mathbb{R}^{m \times d}$, which we call inducing points. The inducing points are part of the ISAB itself, and they are trainable parameters which we train along with the other parameters of the network. An ISAB with $m$ inducing points is defined as:
$$\mathrm{ISAB}_m(X) = \mathrm{MAB}(X, H) \in \mathbb{R}^{n \times d}, \quad (8)$$
$$\text{where } H = \mathrm{MAB}(I, X) \in \mathbb{R}^{m \times d}. \quad (9)$$
The ISAB first transforms $I$ into $H$ by attending to the input set $X$. The set of transformed inducing points $H$, which contains information about the input set $X$, is again attended to by the input set to finally produce a set of $n$ elements. This is analogous to low-rank projection or autoencoder models, where inputs ($X$) are first projected onto a low-dimensional object ($H$) and then reconstructed to produce outputs. The difference is that the goal of these methods is reconstruction, whereas the ISAB aims to obtain good features for the final task. We expect the learned inducing points to encode some global structure which helps explain the inputs $X$. As an example, think of a clustering problem on a 2D plane. The inducing points could be appropriately distributed points on the 2D plane so that the encoder can compare elements in the query dataset indirectly through their proximity to these grid points.
Note that in (8) and (9), attention was computed between a set of size $m$ and a set of size $n$. Therefore, the time complexity of $\mathrm{ISAB}_m(X)$ is $O(nm)$, where $m$ is a hyperparameter: an improvement over the quadratic complexity of the SAB. We compare characteristics of various set operations in Table 1. We also emphasize that both of our set operations are permutation equivariant:
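The two-step structure of (8)-(9) can be sketched as follows (a simplified NumPy version without rFF or layer normalization, for shape and equivariance bookkeeping only; the inducing points are fixed random vectors here rather than trained parameters):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mab(X, Y):
    """Simplified MAB: single-head scaled attention plus a residual connection."""
    return X + softmax(X @ Y.T / np.sqrt(X.shape[1])) @ Y

def isab(X, I):
    """Eqs. (8)-(9): inducing points attend to X, then X attends back to the result."""
    H = mab(I, X)      # (m, d): a compressed summary of the input set
    return mab(X, H)   # (n, d): one output per input element, O(nm) attention cost

rng = np.random.default_rng(4)
n, m, d = 100, 16, 8
X = rng.normal(size=(n, d))
I = rng.normal(size=(m, d))   # trainable in the real model; fixed here for illustration
out = isab(X, I)
assert out.shape == (n, d)
# The ISAB is also permutation equivariant.
perm = rng.permutation(n)
assert np.allclose(out[perm], isab(X[perm], I))
```

Both attention calls involve an $n \times m$ weight matrix rather than $n \times n$, which is the source of the $O(nm)$ complexity.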
Definition 1.
We say a function $f : \mathcal{X}^n \to \mathcal{Y}^n$ is permutation equivariant iff for any permutation $\pi \in S_n$, $f(\pi x) = \pi f(x)$. Here $S_n$ is the set of all permutations of indices $\{1, \dots, n\}$.
Property 1.
Both $\mathrm{SAB}(X)$ and $\mathrm{ISAB}_m(X)$ are permutation equivariant.
3.2 Encoder
Using the SAB and ISAB defined above, we construct the encoder of the Set Transformer by stacking multiple SABs or multiple ISABs, for example:
$$\mathrm{Encoder}(X) = \mathrm{SAB}(\mathrm{SAB}(X)), \quad (10)$$
$$\mathrm{Encoder}(X) = \mathrm{ISAB}_m(\mathrm{ISAB}_m(X)). \quad (11)$$
We point out again that the time complexity for a stack of $\ell$ SABs or $\ell$ ISABs is $O(\ell n^2)$ or $O(\ell nm)$, respectively. This can result in much lower processing times when using ISABs (as compared to SABs), while still maintaining high representational power.
3.3 Decoder
After the encoder transforms data $X$ into features $Z \in \mathbb{R}^{n \times d}$, the decoder aggregates them into a single vector which is fed into a feedforward network to get final outputs. A common aggregation scheme is to simply take the average or dimension-wise maximum of the feature vectors (cf. Section 1). We instead aggregate features by applying multihead attention on a learnable set of $k$ seed vectors $S \in \mathbb{R}^{k \times d}$. We call this scheme Pooling by Multihead Attention (PMA):
$$\mathrm{PMA}_k(Z) = \mathrm{MAB}(S, \mathrm{rFF}(Z)), \quad (12)$$
$$H = \mathrm{SAB}(\mathrm{PMA}_k(Z)). \quad (13)$$
Note that the output of $\mathrm{PMA}_k$ is a set of $k$ items. In most cases, using one seed vector ($k = 1$) and no SAB sufficed. However, when the problem of interest requires $k$ correlated outputs, the natural thing to do is to use $k$ seed vectors. An example of such a problem is clustering, where the desired output is $k$ cluster centers. In this case, the additional SAB (13) was crucial because it allowed the network to directly take the correlation between the pooled features into account. Intuitively, feature aggregation using attention should be beneficial because the influence of each instance on the target is not necessarily equal. For example, consider a problem where the target value is the maximum value of a set of real numbers. Since the target can be recovered using only a single instance (the largest), finding and attending to that instance during aggregation will be advantageous. In the next subsection, we further analyze both the encoder and decoder structures more rigorously.
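The seed-based pooling of (12) can be sketched as follows (a simplified NumPy version without the rFF, with fixed random seed vectors; a sketch, not the trained model), illustrating that PMA produces a fixed-size output that is invariant to reordering the input set:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mab(X, Y):
    """Simplified MAB: single-head scaled attention plus a residual connection."""
    return X + softmax(X @ Y.T / np.sqrt(X.shape[1])) @ Y

def pma(Z, S):
    """Eq. (12) without the rFF: k seed vectors S attend to the encoded set Z."""
    return mab(S, Z)   # (k, d): fixed-size output regardless of the set size n

rng = np.random.default_rng(5)
n, k, d = 50, 4, 8
Z = rng.normal(size=(n, d))
S = rng.normal(size=(k, d))   # learnable seeds in the real model; fixed here
out = pma(Z, S)
assert out.shape == (k, d)
# PMA is permutation *invariant* in Z: reordering Z leaves the output unchanged.
assert np.allclose(out, pma(Z[rng.permutation(n)], S))
```

Composing a permutation equivariant encoder (SAB or ISAB stacks) with this invariant decoder yields the overall permutation invariance stated in Proposition 1.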
3.4 Analysis
Since the blocks used to construct the encoder (i.e., SAB, ISAB) are permutation equivariant, the mapping of the encoder is permutation equivariant as well. Combined with the fact that the PMA in the decoder is a permutation invariant transformation, we have the following:
Proposition 1.
The Set Transformer is permutation invariant.
Being able to approximate any function is a desirable property, especially for blackbox models such as deep neural networks. Building on previous results about the universal approximation of permutation invariant functions, we prove the universality of Set Transformers:
Proposition 2.
The Set Transformer is a universal approximator of permutation invariant functions.
Proof.
See Appendix A. ∎
4 Experiments
To evaluate the Set Transformer, we apply it to a suite of tasks involving sets of data points. We repeat all experiments five times and report performance metrics evaluated on the corresponding test datasets. We compared various architectures arising from the combination of the choices of having attention in encoders and decoders, each of which roughly represents existing works as its special cases. Unless specified otherwise, "simple pooling" means average pooling.

rFF + Pooling (Zaheer et al. (2017)): rFF layers in encoder and simple pooling + rFF layers in decoder.

rFF + PMA (includes Ilse et al. (2018) as special cases): rFF layers in encoder and PMA (followed by stack of SABs) in decoder.

Set Transformer: Stack of SABs (ISABs) in encoder and PMA (followed by stack of SABs) in decoder.
4.1 Toy Problem: Maximum Value Regression
To demonstrate the advantage of attention-based set aggregation over simple pooling operations, we consider a toy problem: regression to the maximum value of a given set. Given a set of real numbers $\{x_1, \dots, x_n\}$, the goal is to return $\max_i x_i$. We use the mean absolute error between the prediction and the true maximum as the loss function. We constructed simple pooling architectures with three different pooling operations: mean, sum, and max. We report loss values after training in Table 2. Mean- and sum-pooling architectures result in a high mean absolute error (MAE). The model with max-pooling can predict the output perfectly by learning its encoder to be an identity function, and thus achieves the highest performance. Notably, the Set Transformer achieves performance comparable to the max-pooling model, which underlines the importance of the additional flexibility granted by attention mechanisms: it can learn to find and attend to the maximum element.
Table 2: Mean absolute error on the max-value regression task.

| Architecture | MAE |
|---|---|
| rFF + Pooling (mean) | 2.133 ± 0.190 |
| rFF + Pooling (sum) | 1.902 ± 0.137 |
| rFF + Pooling (max) | 0.1355 ± 0.0074 |
| Set Transformer | 0.2085 ± 0.0127 |
Table 3: Results on the unique-character counting task.

| Architecture | Error |
|---|---|
| rFF + Pooling | 0.5618 ± 0.0072 |
| rFF + PMA | 0.5428 ± 0.0076 |
| SAB + Pooling | 0.4477 ± 0.0077 |
| Set Transformer | 0.4178 ± 0.0075 |
4.2 Counting Unique Characters
In order to test the ability of modelling interactions between objects in a set, we introduce a new task of counting unique elements in an input set. We use the Omniglot (Lake et al., 2015) dataset, which consists of 1,623 different handwritten characters from various alphabets, where each character is represented by 20 different images.
We split all characters (and corresponding images) into train, validation, and test sets and only train using images from the train character classes. We generate input sets by sampling between 6 and 10 images and we train the model to predict the number of different characters inside the set. We used a Poisson regression model to predict this number, with the rate given as the output of a neural network. We maximized the log likelihood of this model using stochastic gradient ascent.
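The Poisson regression head described above can be sketched as follows (a minimal, hypothetical version: the rate is made positive with a softplus transform of the network's scalar output, an assumption about the parameterization, and the loss is the negative Poisson log-likelihood):

```python
import numpy as np
from math import lgamma

def softplus(x):
    """Numerically stable softplus for a scalar input."""
    return np.log1p(np.exp(-abs(x))) + max(x, 0.0)

def poisson_nll(logit, count):
    """Negative log-likelihood of a Poisson with rate lam = softplus(logit)."""
    lam = softplus(logit)
    return lam - count * np.log(lam) + lgamma(count + 1)

# The mode of a Poisson with non-integer rate lam is floor(lam), which is how a
# prediction would be scored against the true number of characters in the set.
assert int(np.floor(softplus(2.0))) == 2
# The NLL is lower when the rate is near the observed count.
assert poisson_nll(5.0, 5) < poisson_nll(0.5, 5)
```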
We evaluated model performance using sets of images sampled from the test set of characters. Table 3 reports performance, measured by how often the mode of the Poisson distribution chosen by the network equals the number of characters inside the input set.
4.3 Solving Maximum Likelihood Problems for Mixture of Gaussians
We applied the set-input networks to the task of maximum likelihood estimation for mixtures of Gaussians (MoGs). The log-likelihood of a dataset $X = \{x_1, \dots, x_n\}$ generated from an MoG with $k$ components is
$$\log p(X; \theta) = \sum_{i=1}^{n} \log \sum_{j=1}^{k} \pi_j \, \mathcal{N}\!\big(x_i; \mu_j, \operatorname{diag}(\sigma_j^2)\big). \quad (14)$$
The goal is to learn the optimal parameters $\theta(X) = \{\pi_j, \mu_j, \sigma_j\}_{j=1}^k$. The typical approach to this problem is to run an iterative algorithm such as Expectation-Maximisation (EM) until convergence. Instead, we aim to learn a generic meta-algorithm that directly maps the input set $X$ to $\theta(X)$. One can also view this as amortized maximum likelihood learning. Specifically, given a dataset $X$, we train a neural network $f(\cdot; \eta)$ to output parameters which maximize
$$\mathbb{E}_{X}\!\left[\log p\big(X; \theta = f(X; \eta)\big)\right]. \quad (15)$$
We structured $f(\cdot; \eta)$ as a set-input neural network and learned its parameters using stochastic gradient ascent, where we approximate gradients using minibatches of datasets.
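The objective in (14) can be computed as in the following NumPy sketch (diagonal covariances as in (14), with a log-sum-exp over components for numerical stability; the parameter values below are arbitrary placeholders, not learned outputs):

```python
import numpy as np

def mog_loglik(X, pi, mu, sigma2):
    """log p(X; theta) = sum_i log sum_j pi_j N(x_i; mu_j, diag(sigma2_j))."""
    n, d = X.shape
    diff2 = (X[:, None, :] - mu[None, :, :]) ** 2            # (n, k, d)
    log_norm = -0.5 * (d * np.log(2 * np.pi)
                       + np.log(sigma2).sum(axis=1))         # (k,)
    log_comp = log_norm[None, :] - 0.5 * (diff2 / sigma2[None, :, :]).sum(axis=2)
    a = np.log(pi)[None, :] + log_comp                       # (n, k)
    m = a.max(axis=1, keepdims=True)                         # log-sum-exp trick
    return (m.squeeze(1) + np.log(np.exp(a - m).sum(axis=1))).sum()

rng = np.random.default_rng(6)
k, d = 4, 2
pi = np.full(k, 0.25)
mu = rng.normal(scale=5.0, size=(k, d))
sigma2 = np.ones((k, d))
X = mu[rng.integers(k, size=100)] + rng.normal(size=(100, d))  # sample from the MoG
# Parameters near the truth should score higher than a badly mismatched model.
assert mog_loglik(X, pi, mu, sigma2) > mog_loglik(X, pi, mu + 10.0, sigma2)
```

In training, this quantity (averaged over a minibatch of datasets, with $\theta$ produced by the network) is the objective being ascended.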
We tested Set Transformers along with other set-input networks on two types of datasets. We used four seed vectors for the PMA ($k = 4$), the same as the number of clusters.
Synthetic 2D mixtures of Gaussians: Each dataset contains points on a 2D plane, each sampled from one of four Gaussians.
CIFAR-100 meta-clustering: Each dataset contains images sampled from four random classes in the CIFAR-100 dataset. Each image is represented by a 512-dimensional vector obtained from a pretrained VGG network (Simonyan & Zisserman, 2014).
Table 4: Meta-clustering results. LL0/data and LL1/data are per-datapoint log-likelihoods on the synthetic data, for the network output and after one additional EM update; ARI0 and ARI1 are the corresponding adjusted Rand indices on CIFAR-100.

| Architecture | Synthetic LL0/data | Synthetic LL1/data | CIFAR-100 ARI0 | CIFAR-100 ARI1 |
|---|---|---|---|---|
| Oracle | 1.4726 | | 0.9150 | |
| rFF + Pooling | 2.0006 ± 0.0123 | 1.6186 ± 0.0042 | 0.5593 ± 0.0149 | 0.5693 ± 0.0171 |
| SAB + Pooling | 1.6772 ± 0.0066 | 1.5070 ± 0.0115 | 0.5831 ± 0.0341 | 0.5943 ± 0.0337 |
| ISAB (16) + Pooling | 1.6955 ± 0.0730 | 1.4742 ± 0.0158 | 0.5672 ± 0.0124 | 0.5805 ± 0.0122 |
| rFF + PMA | 1.6680 ± 0.0040 | 1.5409 ± 0.0037 | 0.7612 ± 0.0237 | 0.7670 ± 0.0231 |
| Set Transformer | 1.5145 ± 0.0046 | 1.4619 ± 0.0048 | 0.9015 ± 0.0097 | 0.9024 ± 0.0097 |
| Set Transformer (16) | 1.5009 ± 0.0068 | 1.4530 ± 0.0037 | 0.9210 ± 0.0055 | 0.9223 ± 0.0056 |
We report the performance of the oracle and of the different models in Table 4. Additionally, it contains scores attained by all models after a single EM update. Overall, the Set Transformer found accurate parameters and even outperformed the oracle after a single EM update. This can be explained by the relatively small size of the input sets, which leads to some clusters having fewer than 10 points. In this regime, sample statistics can differ from population statistics, which limits the performance of the oracle, but the Set Transformer can adapt accordingly. Notably, the Set Transformer with only 16 inducing points showed the best performance, even outperforming the full Set Transformer. We believe this is due to the knowledge transfer and regularization via inducing points, helping the network to learn global structures. Our results also imply that the improvements from using the PMA are more significant than those from using SABs, supporting our claim of the importance of attention-based decoders. We provide detailed generative processes, network architectures, and training schemes, along with additional experiments with various numbers of inducing points, in Section B.3.
4.4 Meta Set Anomaly Detection
Table 5: Meta set anomaly detection results on CelebA.

| Architecture | Test AUROC | Test AUPR |
|---|---|---|
| Random guess | 0.5 | 0.125 |
| rFF + Pooling | 0.5643 ± 0.0139 | 0.4126 ± 0.0108 |
| SAB + Pooling | 0.5757 ± 0.0143 | 0.4189 ± 0.0167 |
| rFF + PMA | 0.5756 ± 0.0130 | 0.4227 ± 0.0127 |
| Set Transformer | 0.5941 ± 0.0170 | 0.4386 ± 0.0089 |
We evaluate our methods on the task of meta-anomaly detection within a set using the CelebA dataset. The dataset consists of 202,599 images with a total of 40 attributes. We randomly sample 1,000 sets of images. For every set, we select two attributes at random and construct the set by selecting seven images containing both attributes and one image with neither. The goal of this task is to find the image that does not belong to the set. We give a detailed description of the experimental setup in Section B.4. Table 5 contains empirical results, which show that Set Transformers outperformed all other methods by a significant margin.
4.5 Point Cloud Classification
We evaluated Set Transformers on a classification task using the ModelNet40 (Chang et al., 2015) dataset, containing 40 categories of three-dimensional objects. Each object is represented as a point cloud, which we treat as a set of points in three-dimensional space. Table 6 contains experimental results on these point clouds. (The point-cloud dataset used in this experiment was obtained directly from the authors of Zaheer et al. (2017).) In this setting, MABs turned out to be prohibitively expensive due to their quadratic time complexity. Additional results with different numbers of points, along with experiment details, are available in Section B.5. Note that ISAB (16) + Pooling outperformed the Set Transformer (ISAB (16) + PMA (1)) by a large margin. Our interpretation is that the class of a point cloud object could be efficiently represented by a simple aggregation of point features, and the PMA suffered from an optimization issue in this setting. We would like to point out that PMA outperformed simple pooling in all other experiments.
Table 6: Point cloud classification accuracy on ModelNet40.

| Architecture | Accuracy |
|---|---|
| rFF + Pooling | 0.8551 ± 0.0142 |
| rFF + PMA (1) | 0.8534 ± 0.0152 |
| ISAB (16) + Pooling | 0.8915 ± 0.0144 |
| Set Transformer (16) | 0.8662 ± 0.0149 |
| rFF + Pooling (Zaheer et al., 2017) | 0.83 ± 0.01 |
| rFF + Pooling + tricks (Zaheer et al., 2017) | 0.87 ± 0.01 |
5 Related Works
Pooling architectures for permutation invariant mappings Pooling architectures for sets have been used in various problems such as 3D shape recognition (Shi et al., 2015; Su et al., 2015), discovering causality (Lopez-Paz et al., 2016), learning the statistics of a set (Edwards & Storkey, 2017), few-shot image classification (Snell et al., 2017), and conditional regression and classification (Garnelo et al., 2018). Zaheer et al. (2017) discusses the structure in general and provides a partial proof of the universality of the pooling architecture.
Attention-based approaches for sets Vinyals et al. (2016) propose an architecture to map sets into sequences, where elements in a set are pooled by a weighted average with weights computed from an attention mechanism. Several recent works have highlighted the competency of attention mechanisms in modeling sets. Yang et al. (2018) propose AttSets for multi-view 3D reconstruction, where attention is applied to the encoded features of the elements in a set before pooling. Similarly, Ilse et al. (2018) use attention in pooling for multiple instance learning. Although not permutation invariant, the model of Mishra et al. (2018) has attention as one of its core components to meta-learn to solve various tasks using sequences of inputs.
Modeling interactions between elements in sets An important reason to use the Transformer is to explicitly model higher-order interactions among the elements in a set. Santoro et al. (2017) propose the relational network, a simple architecture that sum-pools all pairwise interactions of elements in a given set, but not higher-order interactions. Similarly to our work, Ma et al. (2018) use the Transformer to model interactions between the objects in a video. They use mean-pooling to obtain aggregated features, which they feed into an LSTM.
Inducing point methods The idea of letting trainable vectors directly interact with datapoints is loosely based on the inducing point methods used in sparse Gaussian processes (Quiñonero-Candela & Rasmussen, 2005) and the Nyström method for matrix decomposition (Fowlkes et al., 2004). The trainable inducing points can also be seen as independent memory cells accessed with an attention mechanism. The differentiable neural dictionary (Pritzel et al., 2017) stores previous experience as key-value pairs and uses this to process queries. One can view the ISAB as an inversion of this idea, where queries are stored and the input features are used as key-value pairs.
6 Conclusion
In this paper, we introduced the Set Transformer, an attention-based set-input neural network architecture. Our proposed method uses attention mechanisms for both encoding and aggregating features, and we have empirically validated that both of them are necessary for modelling complicated interactions among elements of a set. We also proposed an inducing point method for self-attention, which makes our approach scalable to large sets. We also showed useful theoretical properties of our model, including the fact that it is a universal approximator for permutation invariant functions. To the best of our knowledge, no previous work has successfully trained a neural network to perform amortized clustering in a single forward pass. An interesting topic for future work would be to apply Set Transformers to meta-learning problems other than meta-clustering. In particular, using Set Transformers to meta-learn posterior inference in Bayesian models seems like a promising line of research. Another exciting extension of our work would be to model the uncertainty in set functions by injecting noise variables into Set Transformers in a principled way.
References
 Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv:1607.06450, 2016.
 Chang et al. (2015) Angel X. Chang, Thomas Funkhouser, Leonidas Guibas, Pat Hanrahan, Qixing Huang, Zimo Li, Silvio Savarese, Manolis Savva, Shuran Song, Hao Su, Jianxiong Xiao, Li Yi, and Fisher Yu. Shapenet: an informationrich 3d model repository. arXiv:1512.03012, 2015.
 Charles et al. (2017) R Qi Charles, Hao Su, Mo Kaichun, and Leonidas J Guibas. Pointnet: Deep learning on point sets for 3d classification and segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017.
 Dietterich et al. (1997) Thomas G. Dietterich, Richard H. Lathrop, and Tomás Lozano-Pérez. Solving the multiple instance problem with axis-parallel rectangles. Artificial Intelligence, 89(1-2):31–71, 1997.
 Edwards & Storkey (2017) Harrison Edwards and Amos Storkey. Towards a neural statistician. In Proceedings of the International Conference on Learning Representations (ICLR), 2017.
 Finn et al. (2017) Chelsea Finn, Pieter Abbeel, and Sergey Levine. Modelagnostic metalearning for fast adaptation of deep networks. In Proceedings of the International Conference on Machine Learning (ICML), 2017.
 Fowlkes et al. (2004) Charless Fowlkes, Serge Belongie, Fan Chung, and Jitendra Malik. Spectral grouping using the Nyström method. IEEE Transactions on Pattern Analysis and Machine Intelligence, 25(2):215–225, February 2004.
 Garnelo et al. (2018) Marta Garnelo, Dan Rosenbaum, Chris J. Maddison, Tiago Ramalho, David Saxton, Murray Shanahan, Yee Whye Teh, Danilo J. Rezende, and S. M. Ali Eslami. Conditional neural processes. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
 Graves et al. (2013) Alex Graves, Abdelrahman Mohamed, and Geoffrey Hinton. Speech recognition with deep recurrent neural networks. In Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP), 2013.
 Ilse et al. (2018) Maximilian Ilse, Jakub M. Tomczak, and Max Welling. Attentionbased deep multiple instance learning. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
 Kingma & Ba (2015) Diederik P. Kingma and Jimmy Ba. Adam: a method for stochastic optimization. In Proceedings of the International Conference on Learning Representations (ICLR), 2015.
 Krizhevsky et al. (2012) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems (NIPS), 2012.
 Lake et al. (2015) Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Humanlevel concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.
 Lee & Choi (2018) Yoonho Lee and Seungjin Choi. Gradientbased metalearning with learned layerwise metric and subspace. Proceedings of the International Conference on Machine Learning (ICML), 2018.
 Lopez-Paz et al. (2016) David Lopez-Paz, Robert Nishihara, Soumith Chintala, Bernhard Schölkopf, and Léon Bottou. Discovering causal signals in images. arXiv:1605.08179, 2016.
 Ma et al. (2018) ChihYao Ma, Asim Kadav, Iain Melvin, Zsolt Kira, Ghassan AlRegib, and Hans Peter Graf. Attend and interact: higherorder object interactions for video understanding. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
 Maron & LozanoPérez (1998) Oded Maron and Tomás LozanoPérez. A framework for multipleinstance learning. In Advances in Neural Information Processing Systems (NIPS), 1998.
 Mishra et al. (2018) Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive metalearner. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
 Muandet et al. (2012) Krikamol Muandet, Kenji Fukumizu, Francesco Dinuzzo, and Bernhard Schölkopf. Learning from distributions via support measure machines. In Advances in Neural Information Processing Systems (NIPS), 2012.
 Pritzel et al. (2017) Alexander Pritzel, Benigno Uria, Sriram Srinivasan, Adria Puigdomenech, Oriol Vinyals, Demis Hassabis, Daan Wierstra, and Charles Blundell. Neural episodic control. arXiv preprint arXiv:1703.01988, 2017.
 Quiñonero-Candela & Rasmussen (2005) Joaquin Quiñonero-Candela and Carl Edward Rasmussen. A unifying view of sparse approximate Gaussian process regression. Journal of Machine Learning Research, 6:1939–1959, 2005.
 Santoro et al. (2017) Adam Santoro, David Raposo, David G. T. Barrett, Mateusz Malinowski, Razvan Pascanu, and Peter Battaglia. A simple neural network module for relational reasoning. In Advances in Neural Information Processing Systems (NIPS), 2017.
 Schmidhuber (1987) Jürgen Schmidhuber. Evolutionary Principles in Self-Referential Learning. PhD thesis, Technical University of Munich, 1987.
 Shi et al. (2015) Baoguang Shi, Song Bai, Zhichao Zhou, and Xiang Bai. DeepPano: deep panoramic representation for 3D shape recognition. IEEE Signal Processing Letters, 22(12):2339–2343, 2015.
 Simonyan & Zisserman (2014) Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv:1409.1556, 2014.
 Snell et al. (2017) Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems (NIPS), 2017.
 Su et al. (2015) Hang Su, Subhransu Maji, Evangelos Kalogerakis, and Erik Learned-Miller. Multi-view convolutional neural networks for 3D shape recognition. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), 2015.
 Thrun & Pratt (1998) Sebastian Thrun and Lorien Pratt. Learning to Learn. Kluwer Academic Publishers, 1998.
 Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems (NIPS), 2017.
 Vinyals et al. (2016) Oriol Vinyals, Samy Bengio, and Manjunath Kudlur. Order matters: sequence to sequence for sets. In Proceedings of the International Conference on Learning Representations (ICLR), 2016.
 Wu et al. (2015) Zhirong Wu, Shuran Song, Aditya Khosla, Fisher Yu, Linguang Zhang, Xiaoou Tang, and Jianxiong Xiao. 3D ShapeNets: a deep representation for volumetric shapes. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2015.
 Yang et al. (2018) Bo Yang, Sen Wang, Andrew Markham, and Niki Trigoni. Attentional aggregation of deep feature sets for multi-view 3D reconstruction. arXiv:1808.00758, 2018.
 Zaheer et al. (2017) Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Ruslan R. Salakhutdinov, and Alexander J. Smola. Deep sets. In Advances in Neural Information Processing Systems (NIPS), 2017.
Appendix A Proofs
Lemma 1.
The mean operator mean(X) = (1/n) Σ_{j=1}^n x_j is a special case of dot-product attention with softmax.
Proof.
Let q be the zero vector and let K = V = X = [x_1, …, x_n]^T. Then softmax(qK^T) = [1/n, …, 1/n], so Att(q, K, V) = softmax(qK^T)V = (1/n) Σ_{j=1}^n x_j = mean(X).
∎
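The construction above can be checked numerically. The following is a minimal NumPy sketch; the 1/√d scaling of the logits is omitted because it has no effect when the query is the zero vector.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(q, K, V):
    # Att(q, K, V) = softmax(q K^T) V; scaling omitted (no effect for q = 0)
    return softmax(q @ K.T) @ V

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))   # a set of n = 5 elements in R^3
q = np.zeros((1, 3))          # zero query -> uniform softmax weights

# attention with a zero query recovers the mean of the set
assert np.allclose(dot_product_attention(q, X, X), X.mean(axis=0))
```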
Lemma 2.
The decoder of a Set Transformer, given enough nodes, can express any element-wise function of the form Z ↦ g((1/n) Σ_{i=1}^n f(z_i)), where f and g act dimension-wise on their inputs.
Proof.
We first note that we can view the decoder as the composition of functions
(16) Decoder(Z) = rFF(SAB(H)),
(17) H = PMA_k(Z) = MAB(S, rFF(Z)).
We focus on the MAB in equation (17). Since feedforward networks are universal function approximators in the limit of infinite nodes, let the feedforward layers in front of and behind the MAB encode the element-wise functions f and g, respectively. We let the number of heads equal d, the dimensionality of the inputs, so that each head is one-dimensional. Let the projection matrices in multihead attention (the W_j) represent projections onto the j-th dimension and the output matrix (W^O) the identity matrix. Since the mean operator is a special case of dot-product attention (Lemma 1), by simple composition we see that an MAB can express any dimension-wise function of the form
(18) Z ↦ g((1/n) Σ_{i=1}^n f(z_i)).
∎
Lemma 3.
A PMA, given enough nodes, can express sum pooling: Z ↦ Σ_{j=1}^n z_j.
Proof.
We prove this by construction.
Set the seed vector to zero and replace the softmax in the attention with an element-wise activation ω satisfying ω(0) ≠ 0. Since the seed is zero, every attention logit is zero, so each value receives the same weight ω(0); the output of the multihead attention is then ω(0) Σ_j z_j, a sum of the values up to a multiplicative constant, which the subsequent feedforward layer can absorb. The identity, sigmoid, and ReLU functions remain suitable choices for the activations of the feedforward layers. ∎
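This construction can likewise be verified numerically. The sketch below assumes the softmax is replaced by an element-wise activation applied to the logits (here the sigmoid, so every weight equals sigmoid(0) = 0.5); with a zero seed the attention output is sum pooling up to a constant factor.

```python
import numpy as np

def unnormalized_attention(q, K, V, act):
    # dot-product attention with the softmax replaced by an
    # element-wise activation applied to the logits
    return act(q @ K.T) @ V

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
Z = rng.normal(size=(6, 4))   # a set of n = 6 elements in R^4
s = np.zeros((1, 4))          # zero seed vector -> all logits are zero

# every element receives weight sigmoid(0) = 0.5, so the output is
# sum pooling up to a constant that a later layer can absorb
out = unnormalized_attention(s, Z, Z, sigmoid)
assert np.allclose(out, 0.5 * Z.sum(axis=0))
```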
We additionally have the following universality theorem for pooling architectures:
Theorem 1.
Models of the form f(X) = ρ(Σ_{x ∈ X} φ(x)) are universal function approximators in the space of permutation invariant functions.
Proof.
See Appendix A of Zaheer et al. (2017). ∎
By Lemma 3, we know that the decoder can express any function of the form ρ(Σ_j z_j). Using this fact along with Theorem 1, we can prove the universality of Set Transformers:
Proposition 2.
The Set Transformer is a universal function approximator in the space of permutation invariant functions.
Proof.
By setting the value projection matrices to zero matrices in every SAB and ISAB, we can ignore all pairwise interaction terms in the encoder, so that each block reduces to its feedforward layers. Therefore, the encoder can express any instance-wise feedforward network φ. Directly invoking Theorem 1 concludes this proof. ∎
While this proof required us to ignore the pairwise interaction terms inside the SABs and ISABs, our experiments indicated that self-attention in the encoder was crucial for good performance.
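The invariance property itself is easy to illustrate: the following toy model (a single-head self-attention layer followed by mean pooling, with residual connections and layer normalization omitted for brevity; not the full architecture) produces the same output under any permutation of its input set.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    # single-head self-attention; permutation-equivariant in the rows of X
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    return softmax(Q @ K.T / np.sqrt(K.shape[1])) @ V

def toy_set_model(X, Wq, Wk, Wv):
    # equivariant encoder followed by permutation-invariant mean pooling
    return self_attention(X, Wq, Wk, Wv).mean(axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(7, 4))
Wq = rng.normal(size=(4, 4))
Wk = rng.normal(size=(4, 4))
Wv = rng.normal(size=(4, 4))
perm = rng.permutation(7)

# permuting the input set leaves the output unchanged
assert np.allclose(toy_set_model(X, Wq, Wk, Wv),
                   toy_set_model(X[perm], Wq, Wk, Wv))
```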
Appendix B Experiment Details
In all implementations, we omit the feedforward layer at the beginning of the decoder because the end of the previous block already contains a feedforward layer. All MABs (inside SAB, ISAB, and PMA) use fully-connected layers with ReLU activations for their rFF layers.
In the architecture descriptions, FC(d, f) denotes a fully-connected layer with d units and activation function f; SAB(d, h) denotes an SAB with d units and h heads; ISAB(d, h, m) denotes an ISAB with d units, h heads, and m inducing points; and PMA(d, h, k) denotes a PMA with d units, h heads, and k seed vectors.
B.1 Max Regression
Given a set of real numbers, the goal of this task is to return the maximum value in the set. We construct training data as follows. We first sample a set size n uniformly from a range of integers. We then sample n real numbers independently from a fixed interval. Given the network's prediction and the actual maximum value, we compute the mean absolute error. We do not explicitly split train and test data, since we sample a new set at each training step.
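The data-generating procedure can be sketched as follows; the set-size range and value interval below are illustrative placeholders, not the values used in our experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_max_task(max_set_size=10, low=0.0, high=100.0):
    # the set-size range and value interval are placeholder choices
    n = rng.integers(1, max_set_size + 1)
    x = rng.uniform(low, high, size=n)
    return x, x.max()

def mean_absolute_error(preds, targets):
    return float(np.mean(np.abs(np.asarray(preds) - np.asarray(targets))))

x, target = sample_max_task()
assert target == x.max()
```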
[Table 8: architectures for max regression — FF or SAB encoder with Pooling or PMA decoder; layer specifications not recovered.]
B.2 Counting Unique Characters
Architecture  Accuracy 

rFF + Pooling  0.4366 ± 0.0071 
rFF + PMA  0.4617 ± 0.0073 
SAB + Pooling  0.5659 ± 0.0067 
Set Transformer (SAB + PMA (1))  0.6037 ± 0.0072 
Set Transformer (SAB + PMA (2))  0.5806 ± 0.0075 
Set Transformer (SAB + PMA (4))  0.5945 ± 0.0072 
Set Transformer (SAB + PMA (8))  0.6001 ± 0.0078 
[Table 9: architectures for counting unique characters — rFF or SAB encoder with Pooling or PMA decoder; layer specifications not recovered.]
The task generation procedure is as follows. We first sample a set size n uniformly from a range of integers. We then sample the number of distinct characters c uniformly from another range. We sample c characters from the training set of characters and randomly sample instances of each character so that the total number of instances sums to n and each sampled character has at least one instance in the resulting set.
We show the detailed architectures used for the experiments in Table 9. For both architectures, the resulting one-dimensional output is passed through an activation to produce the Poisson parameter λ; the role of this activation is to ensure that λ is always positive.
The loss function we optimize, as previously mentioned, is the Poisson log likelihood log p(k | λ). We chose this loss over mean squared error or mean absolute error because it is the more principled choice when matching a real-valued prediction to a target count. Early experiments showed that directly optimizing mean absolute error gave roughly the same result as optimizing the log likelihood and measuring mean absolute error. We train using the Adam optimizer with a constant learning rate.
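For concreteness, the loss can be sketched as follows, assuming a softplus link to produce the positive Poisson rate (one common choice; the specific activation is not fixed here, only its positivity matters).

```python
import math

def softplus(x):
    # maps any real number to a strictly positive value
    return math.log1p(math.exp(x))

def poisson_nll(raw_output, k):
    # negative log-likelihood of an integer count k under Poisson(lam),
    # with lam = softplus(raw_output) guaranteeing lam > 0
    lam = softplus(raw_output)
    return lam - k * math.log(lam) + math.lgamma(k + 1)

# the loss is minimized when lam is close to the target count:
# over integer raw outputs in [-5, 19], the minimizer for target k = 4
# is raw_output = 4 (softplus(4) is approximately 4.02)
best = min(range(-5, 20), key=lambda r: poisson_nll(float(r), 4))
```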
B.3 Solving maximum likelihood problems for mixtures of Gaussians
B.3.1 Details for 2D synthetic mixtures of Gaussians experiment
We generated the datasets according to the following generative process.

1. Generate the number of data points, n.

2. Generate k cluster centers:
(19) θ_j ~ Uniform over a fixed square, j = 1, …, k.

3. Generate cluster labels:
(20) z_i ~ Categorical(1/k, …, 1/k), i = 1, …, n.

4. Generate data from a spherical Gaussian:
(21) x_i ~ N(θ_{z_i}, σ² I), i = 1, …, n.
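The generative process above can be sketched as follows; the set-size range, the spread of the centers, and the component variance in this sketch are illustrative placeholders, not the values used in our experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_gmm_dataset(k=4, dim=2, n_max=500, center_scale=4.0, sigma=0.3):
    # n_max, center_scale, and sigma are placeholder choices
    n = rng.integers(k, n_max + 1)                     # number of data points
    centers = rng.uniform(-center_scale, center_scale, size=(k, dim))
    labels = rng.integers(0, k, size=n)                # cluster assignments
    points = centers[labels] + sigma * rng.normal(size=(n, dim))
    return points, labels, centers

X, y, mu = sample_gmm_dataset()
assert X.shape == (y.shape[0], 2) and mu.shape == (4, 2)
```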
Table 10 summarizes the architectures used for the experiments. For all architectures, at each training step, we generate 10 random datasets according to the above generative process and update the parameters with the Adam optimizer. We trained all the algorithms for a fixed number of steps, decaying the initial learning rate partway through training. Table 11 summarizes the detailed results with various numbers of inducing points in the ISAB. Figure 3 shows the actual clustering results based on the predicted parameters.
[Table 10: architectures for the 2D mixture of Gaussians experiment — rFF, SAB, or ISAB encoder with Pooling or PMA decoder; layer specifications not recovered.]
Architecture  LL0/data  LL1/data 

Oracle  -1.4726  
rFF + Pooling  -2.0006 ± 0.0123  -1.6186 ± 0.0042 
SAB + Pooling  -1.6772 ± 0.0066  -1.5070 ± 0.0115 
ISAB (16) + Pooling  -1.6955 ± 0.0730  -1.4742 ± 0.0158 
ISAB (32) + Pooling  -1.6353 ± 0.0182  -1.4681 ± 0.0038 
ISAB (64) + Pooling  -1.6349 ± 0.0429  -1.4664 ± 0.0080 
rFF + PMA  -1.6680 ± 0.0040  -1.5409 ± 0.0037 
Set Transformer  -1.5145 ± 0.0046  -1.4619 ± 0.0048 
Set Transformer (16)  -1.5009 ± 0.0068  -1.4530 ± 0.0037 
Set Transformer (32)  -1.4963 ± 0.0064  -1.4524 ± 0.0044 
Set Transformer (64)  -1.5042 ± 0.0158  -1.4535 ± 0.0053 
B.3.2 Details for CIFAR-100 meta clustering experiment
We pretrained a VGG network (Simonyan & Zisserman, 2014) on CIFAR-100 and obtained a test accuracy of 68.54%. We then extracted feature vectors for the 50k training images of CIFAR-100 from the 512-dimensional hidden layer of the VGG network (the layer just before the last layer). Given these feature vectors, the generative process for datasets is as follows.

1. Generate the number of data points, n.

2. Uniformly sample four classes among the 100 classes.

3. Uniformly sample n data points from the four sampled classes.
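This sampling procedure can be sketched as follows; the maximum set size is an illustrative placeholder, and the feature matrix here is a tiny synthetic stand-in for the pretrained VGG embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_meta_clustering_task(features, labels, n_classes=4, n_max=500):
    # features: (N, d) pretrained embeddings, labels: (N,) class ids;
    # n_max is a placeholder cap on the set size
    classes = rng.choice(np.unique(labels), size=n_classes, replace=False)
    pool = np.flatnonzero(np.isin(labels, classes))
    n = rng.integers(n_classes, min(n_max, pool.size) + 1)
    idx = rng.choice(pool, size=n, replace=False)
    return features[idx], labels[idx]

# synthetic stand-in for the 50k 512-dimensional CIFAR-100 features
feats = rng.normal(size=(1000, 512))
labs = rng.integers(0, 100, size=1000)
Xs, ys = sample_meta_clustering_task(feats, labs)
assert len(np.unique(ys)) <= 4
```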
Table 12 summarizes the architectures used for the experiments. For all architectures, at each training step, we generate 10 random datasets according to the above generative process and update the parameters with the Adam optimizer. We trained all the algorithms for a fixed number of steps, decaying the initial learning rate partway through training. Table 13 summarizes the detailed results with various numbers of inducing points in the ISAB.
[Table 12: architectures for the CIFAR-100 meta clustering experiment — rFF, SAB, or ISAB encoder with rFF or PMA decoder; layer specifications not recovered.]
Architecture  ARI0  ARI1 

Oracle  0.9151  
rFF + Pooling  0.5593 ± 0.0149  0.5693 ± 0.0171 
SAB + Pooling  0.5831 ± 0.0341  0.5943 ± 0.0337 
ISAB (16) + Pooling  0.5672 ± 0.0124  0.5805 ± 0.0122 
ISAB (32) + Pooling  0.5587 ± 0.0104  0.5700 ± 0.0134 
ISAB (64) + Pooling  0.5586 ± 0.0205  0.5708 ± 0.0183 
rFF + PMA  0.7612 ± 0.0237  0.7670 ± 0.0231 
Set Transformer  0.9015 ± 0.0097  0.9024 ± 0.0097 
Set Transformer (16)  0.9210 ± 0.0055  0.9223 ± 0.0056 
Set Transformer (32)  0.9103 ± 0.0061  0.9119 ± 0.0052 
Set Transformer (64)  0.9141 ± 0.0040  0.9153 ± 0.0041 
B.4 Meta Set Anomaly
[Table 14: architectures for the meta set anomaly experiment — rFF or SAB encoder with Pooling or PMA decoder; layer specifications not recovered.]
Table 14 describes the architecture for the meta set anomaly experiments. We trained all models with the Adam optimizer, decaying the learning rate exponentially, for 1,000 iterations. We used 1,000 datasets subsampled from the CelebA dataset (see Figure 4), split into 800 training datasets and 200 test datasets.
B.5 Point Cloud Classification
We used the ModelNet40 dataset for our point cloud classification experiments. This dataset consists of three-dimensional representations of 9,843 training and 2,468 test shapes, each belonging to one of 40 object classes. As input to our architectures, we produce point clouds with a fixed number of points each (each point is represented by its (x, y, z) coordinates). For better generalization, we randomly rotate and scale each set during training.
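The augmentation can be sketched as follows; rotating about the vertical axis and the particular scale range shown are illustrative choices, not necessarily those used in our experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_point_cloud(points, scale_low=0.8, scale_high=1.25):
    # random rotation about the vertical axis plus isotropic rescaling;
    # the rotation axis and the scale range are placeholder choices
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, 0.0, s],
                    [0.0, 1.0, 0.0],
                    [-s, 0.0, c]])
    scale = rng.uniform(scale_low, scale_high)
    return scale * points @ rot.T

cloud = rng.normal(size=(100, 3))
aug = augment_point_cloud(cloud)
assert aug.shape == cloud.shape
```

Since the rotation is orthogonal and the rescaling isotropic, every point's distance from the origin changes by the same factor, which is a quick sanity check on the implementation.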
We show results for our architectures in Table 15, along with additional experiments that used a different number of points in Table 16. We trained using the Adam optimizer with an initial learning rate that we decayed by a constant factor at regular intervals.
[Architectures for point cloud classification — rFF or ISAB encoder with Pooling or PMA decoder; layer specifications not recovered.]
Architecture  Accuracy 

rFF + Pooling  0.7951 ± 0.0166 
rFF + PMA (1)  0.8076 ± 0.0160 
ISAB (16) + Pooling  0.8273 ± 0.0159 
Set Transformer (16)  0.8454 ± 0.0144 
rFF + Pooling + tricks (Zaheer et al., 2017)  0.82 ± 0.02 