Many machine learning tasks such as multiple instance learning, 3D shape recognition and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the permutation of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from sparse Gaussian process literature. It reduces computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating increased performance compared to recent methods for set-structured data.
Learning representations has proven to be an essential problem for deep learning and its many success stories. The majority of problems tackled by deep learning are instance-based and take the form of mapping a fixed-dimensional input tensor to its corresponding target value (Krizhevsky et al., 2012; Graves et al., 2013). For some applications, we are required to process set-structured data. Multiple instance learning (Dietterich et al., 1997; Maron & Lozano-Pérez, 1998) is an example of such a set-input problem, where a set of instances is given as an input and the corresponding target is a label for the entire set. Other problems such as 3D shape recognition (Wu et al., 2015; Shi et al., 2015; Su et al., 2015; Charles et al., 2017), sequence ordering (Vinyals et al., 2016), and various set operations (Muandet et al., 2012; Edwards & Storkey, 2017; Zaheer et al., 2017) can also be viewed as such set-input problems. Moreover, many meta-learning (Thrun & Pratt, 1998; Schmidhuber, 1987) problems which learn using a set of different but related tasks may also be treated as set-input tasks where an input set corresponds to the training dataset of a single task. For example, few-shot image classification (Finn et al., 2017; Snell et al., 2017; Lee & Choi, 2018) operates by building a classifier using a support set of images, which is evaluated with query images.
A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size. While these requirements stem from the definition of a set, they are not easily satisfied in neural-network-based models: classical feed-forward neural networks violate both requirements, and RNNs are sensitive to input order.
Recently, Edwards & Storkey (2017) and Zaheer et al. (2017) proposed neural network architectures that meet both criteria, which we call set pooling methods. In this model, each element in a set is first independently fed into a feed-forward neural network that takes fixed-size inputs. The resulting feature-space embeddings are then aggregated using a pooling operation (sum, mean, max, or similar). The final output is obtained by further non-linear processing of the aggregated embedding. This remarkably simple architecture satisfies both aforementioned requirements and, more importantly, is proven to be a universal approximator for any set function (Zaheer et al., 2017). Thanks to this property, it is possible to learn complex mappings between input sets and their target outputs in a black-box fashion, much like with feed-forward or recurrent neural networks.
Even though this set pooling approach is theoretically attractive, it remains unclear whether we can approximate complex mappings well using only instance-based feature extractors and simple pooling operations. Since every element in a set is processed independently in a set pooling operation, some information regarding interactions between elements is necessarily discarded. This can make some classes of problems unnecessarily difficult to solve.
Consider the problem of meta-clustering: we would like to learn a parametric mapping from an input set of points to the centers of the clusters in that set, for many such sets. Even though a neural network with a set pooling operation can approximate such a mapping by learning to quantize space, this quantization cannot depend on the contents of the set. This limits the quality of the solution and may make optimization of such a model more difficult; we show empirically in Section 4 that it leads to under-fitting.
In this paper, we propose a novel set-input deep neural network architecture called the Set Transformer (cf. Transformer, Vaswani et al. (2017)). The novelty of the Set Transformer comes from three important design choices: 1) We use a self-attention mechanism based on the Transformer to process every element in an input set, which allows our approach to naturally encode pairwise or higher-order interactions between elements in the set. 2) We propose a method to reduce the computation time of Transformers on a set of $n$ elements from $O(n^2)$ to $O(nm)$, where $m$ is a fixed hyperparameter. 3) We use a self-attention mechanism to aggregate features, which is especially beneficial when the problem of interest requires multiple dependent outputs, such as the problem of meta-clustering, where the meaning of each cluster center heavily depends on its location relative to the other clusters. We apply the Set Transformer to several set-input problems and empirically demonstrate the importance and effectiveness of these design choices.
This paper is organized as follows. In Section 2, we briefly review the concept of set functions, existing architectures, and the self-attention mechanism. In Section 3, we introduce Set Transformers, our novel neural network architecture for set functions. In Section 4, we present various experiments that demonstrate the benefits of the Set Transformer. We discuss related works in Section 5 and conclude the paper in Section 6.
2.1 Pooling Architecture for Sets
Problems involving a set of objects have the permutation invariance property: the target value for a given set is the same regardless of the order of objects in the set. A simple example of a permutation invariant model is a network that performs pooling over embeddings extracted from the elements of a set. More formally,

$$\mathrm{net}(\{x_1, \ldots, x_n\}) = \rho\big(\mathrm{pool}(\{\phi(x_1), \ldots, \phi(x_n)\})\big). \quad (1)$$
Zaheer et al. (2017) have proven that all permutation invariant functions can be represented as (1) when $\mathrm{pool}$ is the sum operator and $\rho$, $\phi$ are any continuous functions, thus justifying the use of this architecture for set-input problems.
Note that we can deconstruct (1) into two parts: an encoder ($\phi$), which independently acts on each element of a set of $n$ items, and a decoder ($\rho(\mathrm{pool}(\cdot))$), which aggregates these encoded features and produces our desired output. Most network architectures for set-structured data follow this encoder-decoder structure. Our proposed method is also composed of an encoder and a decoder, but our embedding function does not act independently on each item; it considers the whole set to obtain the embedding. Additionally, instead of a fixed function such as mean, our aggregating function $\mathrm{pool}(\cdot)$ is parameterized and can thus adapt to the problem at hand.
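As a concrete illustration of (1), the NumPy sketch below builds a toy set pooling network with random, untrained weights. The tanh encoder, the linear decoder, and all function names are illustrative assumptions rather than the architectures used in the experiments. It checks that sum pooling makes the output independent of element order and that sets of different sizes are accepted.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, d_out = 3, 16, 2

# Random weights standing in for trained phi and rho (illustrative only).
W_phi = rng.normal(size=(d_in, d_hidden))
W_rho = rng.normal(size=(d_hidden, d_out))

def phi(X):
    """Encoder phi: acts on each row (set element) independently."""
    return np.tanh(X @ W_phi)

def rho(z):
    """Decoder rho: maps the pooled embedding to the output."""
    return z @ W_rho

def set_pooling_net(X, pool=np.sum):
    """net({x_1, ..., x_n}) = rho(pool({phi(x_1), ..., phi(x_n)})) for X of shape (n, d_in)."""
    return rho(pool(phi(X), axis=0))

X = rng.normal(size=(5, d_in))
perm = rng.permutation(5)
print(np.allclose(set_pooling_net(X), set_pooling_net(X[perm])))  # True
print(set_pooling_net(rng.normal(size=(9, d_in))).shape)  # (2,)
```

Because $\phi$ never sees other elements, any interaction between elements can only be captured after pooling, which is the limitation discussed in Section 1.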
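Assume we have $n$ query vectors (corresponding to points in an input set), each with dimension $d_q$: $Q \in \mathbb{R}^{n \times d_q}$. An attention function $\mathrm{Att}(Q, K, V)$ is a function that maps queries $Q$ to outputs using $n_v$ key-value pairs $K \in \mathbb{R}^{n_v \times d_q}$, $V \in \mathbb{R}^{n_v \times d_v}$:

$$\mathrm{Att}(Q, K, V; \omega) = \omega(QK^\top)V. \quad (2)$$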
The pairwise dot product $QK^\top \in \mathbb{R}^{n \times n_v}$ measures how similar each pair of query and key vectors is, with weights computed with an activation function $\omega$. The output $\omega(QK^\top)V$ is a weighted sum of $V$, where a value gets more weight if its corresponding key has a larger dot product with the query.
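A minimal NumPy sketch of (2), assuming a row-wise scaled softmax as the default activation $\omega$; the function names are hypothetical. With one query that is far more similar to the first key than to the second, almost all of the weight goes to the first value.

```python
import numpy as np

def softmax(A, axis=-1):
    """Numerically stable softmax along `axis`."""
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def att(Q, K, V, omega=None):
    """Att(Q, K, V; omega) = omega(Q K^T) V.

    Q: (n, d_q) queries, K: (n_v, d_q) keys, V: (n_v, d_v) values.
    Returns an (n, d_v) matrix whose i-th row is a weighted sum of the rows of V.
    """
    if omega is None:
        # Assumed default: row-wise softmax of the scores scaled by sqrt(d_q).
        d_q = Q.shape[1]
        omega = lambda S: softmax(S / np.sqrt(d_q), axis=1)
    return omega(Q @ K.T) @ V

Q = np.array([[10.0, 0.0]])
K = np.array([[10.0, 0.0], [0.0, 10.0]])
V = np.array([[1.0, 2.0], [3.0, 4.0]])
print(att(Q, K, V))  # approximately [[1. 2.]]
```

Passing a different `omega` (for example, an unnormalized identity map) changes only how scores become weights; the weighted-sum structure of (2) stays the same.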
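Multi-head attention, originally introduced in Vaswani et al. (2017), is an extension of the previous attention scheme. Instead of computing a single attention function, this method first projects $Q, K, V$ onto $h$ different $d_q^M$-, $d_q^M$-, and $d_v^M$-dimensional vectors, respectively. An attention function $\mathrm{Att}(\cdot; \omega_j)$ is applied to each of these $h$ projections. The output is a linear transformation of the concatenation of all attention outputs:

$$\mathrm{Multihead}(Q, K, V; \lambda, \omega) = \mathrm{concat}(O_1, \ldots, O_h)W^O, \quad (3)$$

$$\text{where } O_j = \mathrm{Att}(QW_j^Q, KW_j^K, VW_j^V; \omega_j). \quad (4)$$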
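Note that $\mathrm{Multihead}(\cdot, \cdot, \cdot; \lambda, \omega)$ has learnable parameters $\lambda = \{W_j^Q, W_j^K, W_j^V\}_{j=1}^h \cup \{W^O\}$, where $W_j^Q, W_j^K \in \mathbb{R}^{d_q \times d_q^M}$, $W_j^V \in \mathbb{R}^{d_v \times d_v^M}$, and $W^O \in \mathbb{R}^{h d_v^M \times d}$. A typical choice for the dimension hyperparameters is $d_q^M = d_q/h$, $d_v^M = d_v/h$, $d = d_q$. For brevity, we set $d_q = d_v = d$ and $d_q^M = d_v^M = d/h$ throughout the rest of the paper. Unless specified otherwise, we use the scaled softmax $\omega_j(\cdot) = \mathrm{softmax}(\cdot/\sqrt{d})$, which our experiments showed worked robustly in most settings.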
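The sketch below implements (3) and (4) with the conventions above ($d_q = d_v = d$, $d^M = d/h$, $\omega_j = \mathrm{softmax}(\cdot/\sqrt{d})$). It is an illustrative NumPy version with randomly initialized weights and no bias terms; the initialization scale and the helper names `init_multihead` and `multihead` are assumptions.

```python
import numpy as np

def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def init_multihead(d, h, rng):
    """Random (untrained) W_j^Q, W_j^K, W_j^V for j = 1..h and W^O, with d^M = d / h."""
    assert d % h == 0, "d must be divisible by the number of heads h"
    dm, s = d // h, 1.0 / np.sqrt(d)
    return {
        "WQ": rng.normal(scale=s, size=(h, d, dm)),
        "WK": rng.normal(scale=s, size=(h, d, dm)),
        "WV": rng.normal(scale=s, size=(h, d, dm)),
        "WO": rng.normal(scale=s, size=(h * dm, d)),
    }

def multihead(Q, K, V, p):
    """concat(O_1, ..., O_h) W^O with O_j = softmax((Q W_j^Q)(K W_j^K)^T / sqrt(d)) (V W_j^V)."""
    d = Q.shape[1]
    heads = []
    for WQ, WK, WV in zip(p["WQ"], p["WK"], p["WV"]):
        scores = (Q @ WQ) @ (K @ WK).T / np.sqrt(d)        # (n, n_v) similarities
        heads.append(softmax(scores, axis=1) @ (V @ WV))   # (n, d / h)
    return np.concatenate(heads, axis=1) @ p["WO"]         # (n, d)

rng = np.random.default_rng(0)
params = init_multihead(d=8, h=2, rng=rng)
Q = rng.normal(size=(5, 8))
Y = rng.normal(size=(7, 8))
out = multihead(Q, Y, Y, params)
print(out.shape)  # (5, 8)
```

Permuting the rows of $K$ and $V$ together leaves the output unchanged, since each query's softmax weights are simply reordered along with the values; this is what makes attention over a key-value set a natural building block for set functions.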
3 Set Transformer
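Table 1: Characteristics of set operations ($n$: set size, $m$: number of inducing points).

| Set operation | Time complexity | High-order interactions | Permutation invariant |
|---|---|---|---|
| Pooling (Zaheer et al., 2017) | $O(n)$ | No | Yes |
| Relational Networks (Santoro et al., 2017) | $O(n^2)$ | Yes | Yes |
| Set Transformer (SAB + PMA, ours) | $O(n^2)$ | Yes | Yes |
| Set Transformer (ISAB + PMA, ours) | $O(nm)$ | Yes | Yes |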
In this section, we motivate and describe the Set Transformer: an attention-based neural network architecture that is designed to process sets of data. A Set Transformer consists of an encoder followed by a decoder (cf. Section 2.1). The encoder transforms a set of instances into a set of features, which the decoder transforms into the desired fixed-dimensional output.
3.1 Attention-based set operations
We begin by defining our attention-based set operations. While existing pooling methods for sets obtain instance features independently of other instances, we use self-attention to concurrently encode the whole set. This gives the Set Transformer the ability to preserve pairwise as well as higher-order interactions among instances during the encoding process. For this purpose, we adapt the multihead attention mechanism used in the Transformer. We emphasize that all blocks introduced here are neural network blocks with their own parameters, and not fixed functions.
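Given matrices $X \in \mathbb{R}^{n_x \times d}$ and $Y \in \mathbb{R}^{n_y \times d}$, which represent two sets of $d$-dimensional vectors, we define the Multihead Attention Block (MAB) as follows:

$$\mathrm{MAB}(X, Y) = \mathrm{LayerNorm}(H + \mathrm{rFF}(H)), \quad (5)$$

$$\text{where } H = \mathrm{LayerNorm}(X + \mathrm{Multihead}(X, Y, Y; \lambda, \omega)), \quad (6)$$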
where rFF is any row-wise feedforward layer (i.e., it processes each instance independently and identically), and LayerNorm is layer normalization (Ba et al., 2016). The MAB is an adaptation of the encoder block of the Transformer (Vaswani et al., 2017) without positional encoding and dropout. Using the MAB, we define the Set Attention Block (SAB) as

$$\mathrm{SAB}(X) := \mathrm{MAB}(X, X). \quad (7)$$
In other words, an SAB takes a set and performs self-attention between the elements in the set, resulting in a set of equal size. Since the output of SAB contains information about pairwise interactions between the elements in the input set $X$, we can stack multiple SABs to encode higher-order interactions. Note that while the SAB (7) involves a multihead attention operation (6) with $Q = K = V = X$, it could reduce to applying a residual block on $X$. In practice, it learns more complicated functions due to the linear projections of $X$ inside the attention heads, (2) and (4).
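A minimal NumPy sketch of the MAB (5)-(6) and SAB (7). Several details are assumptions made for brevity: the rFF is a single ReLU layer, LayerNorm has no learnable gain or bias, there is no dropout, the attention projections have no biases, and the weights are random rather than trained.

```python
import numpy as np

def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def layer_norm(X, eps=1e-5):
    """Row-wise LayerNorm; the learnable gain and bias are omitted in this sketch."""
    mu = X.mean(axis=1, keepdims=True)
    return (X - mu) / np.sqrt(X.var(axis=1, keepdims=True) + eps)

def init_mab(d, h, rng):
    """Random (untrained) multihead weights plus a one-layer rFF; shapes follow d^M = d / h."""
    dm, s = d // h, 1.0 / np.sqrt(d)
    p = {name: rng.normal(scale=s, size=(h, d, dm)) for name in ("WQ", "WK", "WV")}
    p["WO"] = rng.normal(scale=s, size=(h * dm, d))
    p["W1"], p["b1"] = rng.normal(scale=s, size=(d, d)), np.zeros(d)
    return p

def multihead(Q, K, V, p):
    d = Q.shape[1]
    heads = [softmax((Q @ WQ) @ (K @ WK).T / np.sqrt(d), axis=1) @ (V @ WV)
             for WQ, WK, WV in zip(p["WQ"], p["WK"], p["WV"])]
    return np.concatenate(heads, axis=1) @ p["WO"]

def rff(H, p):
    """Row-wise feedforward: the same ReLU layer is applied to every instance independently."""
    return np.maximum(H @ p["W1"] + p["b1"], 0.0)

def mab(X, Y, p):
    """MAB(X, Y) = LayerNorm(H + rFF(H)), H = LayerNorm(X + Multihead(X, Y, Y))."""
    H = layer_norm(X + multihead(X, Y, Y, p))
    return layer_norm(H + rff(H, p))

def sab(X, p):
    """SAB(X) = MAB(X, X): self-attention among the elements of X."""
    return mab(X, X, p)

rng = np.random.default_rng(0)
p = init_mab(d=8, h=2, rng=rng)
X = rng.normal(size=(6, 8))
print(sab(X, p).shape)         # (6, 8): same size as the input set
print(mab(X[:3], X, p).shape)  # (3, 8): one output row per query row
```

The output of an MAB always has as many rows as its first argument, which is the property the ISAB and PMA below exploit by choosing a small first argument.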
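A potential problem with using SABs for set-structured data is the quadratic time complexity $O(n^2)$, which may be too expensive for large sets ($n \gg 1$). We thus introduce the Induced Set Attention Block (ISAB), which bypasses this problem. Along with the set $X \in \mathbb{R}^{n \times d}$, we additionally define $m$ $d$-dimensional vectors $I \in \mathbb{R}^{m \times d}$, which we call inducing points. Inducing points are part of the ISAB itself, and they are trainable parameters which we train along with the other parameters of the network. An ISAB with $m$ inducing points $I$ is defined as:

$$\mathrm{ISAB}_m(X) = \mathrm{MAB}(X, H) \in \mathbb{R}^{n \times d}, \quad (8)$$

$$\text{where } H = \mathrm{MAB}(I, X) \in \mathbb{R}^{m \times d}. \quad (9)$$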
The ISAB first transforms $I$ into $H$ by attending to the input set. The set of transformed inducing points $H$, which contains information about the input set $X$, is again attended to by the input set to finally produce a set of $n$ elements. This is analogous to low-rank projection or autoencoder models, where inputs ($X$) are first projected onto a low-dimensional object ($H$) and then reconstructed to produce outputs. The difference is that the goal of these methods is reconstruction, whereas the ISAB aims to obtain good features for the final task. We expect the learned inducing points to encode some global structure that helps explain the inputs $X$. As an example, consider a clustering problem on a 2D plane. The inducing points could be appropriately distributed points on the 2D plane, so that the encoder can compare elements in the query dataset indirectly through their proximity to these inducing points.
Note that in (8) and (9), attention is computed between a set of size $m$ and a set of size $n$. Therefore, the time complexity of $\mathrm{ISAB}_m(X)$ is $O(nm)$, where $m$ is a hyperparameter; this is an improvement over the quadratic complexity of the SAB. We compare characteristics of various set operations in Table 1. We also emphasize that both of our set operations are permutation equivariant:
We say a function $f: \mathcal{X}^n \to \mathcal{Y}^n$ is permutation equivariant iff for any permutation $\pi \in S_n$, $f(\pi x) = \pi f(x)$. Here $S_n$ is the set of all permutations of indices $\{1, \ldots, n\}$.
Both $\mathrm{SAB}(X)$ and $\mathrm{ISAB}_m(X)$ are permutation equivariant.
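Using the SAB and ISAB defined above, we construct the encoder of the Set Transformer by stacking multiple SABs or multiple ISABs, for example:

$$\mathrm{Encoder}(X) = \mathrm{SAB}(\mathrm{SAB}(X)),$$

$$\mathrm{Encoder}(X) = \mathrm{ISAB}_m(\mathrm{ISAB}_m(X)).$$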
We point out again that the time complexities for $\ell$ stacked SABs and ISABs are $O(\ell n^2)$ and $O(\ell nm)$, respectively. This can result in much lower processing times when using ISABs instead of SABs, while still maintaining high representational power.
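The following NumPy sketch builds an ISAB from (8)-(9), stacks two of them into an encoder, and numerically checks the permutation equivariance stated above for both the ISAB encoder and an SAB. It reuses the same simplifying assumptions as a minimal MAB (random weights, LayerNorm without gain or bias, one-layer ReLU rFF, no dropout); sizes such as $m = 4$ are arbitrary. For a sense of scale, with $n = 1000$ and $m = 16$ one SAB forms a $1000 \times 1000$ score matrix per head ($10^6$ entries), whereas one ISAB forms two $16 \times 1000$ matrices ($32{,}000$ entries).

```python
import numpy as np

# Minimal MAB/SAB helpers (assumed: random weights, LayerNorm without gain/bias,
# one-layer ReLU rFF, no dropout, no biases in the attention projections).
def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def layer_norm(X, eps=1e-5):
    mu = X.mean(axis=1, keepdims=True)
    return (X - mu) / np.sqrt(X.var(axis=1, keepdims=True) + eps)

def init_mab(d, h, rng):
    dm, s = d // h, 1.0 / np.sqrt(d)
    p = {name: rng.normal(scale=s, size=(h, d, dm)) for name in ("WQ", "WK", "WV")}
    p["WO"] = rng.normal(scale=s, size=(h * dm, d))
    p["W1"], p["b1"] = rng.normal(scale=s, size=(d, d)), np.zeros(d)
    return p

def mab(X, Y, p):
    d = X.shape[1]
    heads = [softmax((X @ WQ) @ (Y @ WK).T / np.sqrt(d), axis=1) @ (Y @ WV)
             for WQ, WK, WV in zip(p["WQ"], p["WK"], p["WV"])]
    H = layer_norm(X + np.concatenate(heads, axis=1) @ p["WO"])
    return layer_norm(H + np.maximum(H @ p["W1"] + p["b1"], 0.0))

def sab(X, p):
    return mab(X, X, p)

def isab(X, I, p_in, p_out):
    """ISAB_m(X) = MAB(X, H) with H = MAB(I, X), I: (m, d) inducing points.
    The score matrices are (m, n) and (n, m), so the cost is O(nm) rather than O(n^2)."""
    H = mab(I, X, p_in)      # (m, d): inducing points attend to the set
    return mab(X, H, p_out)  # (n, d): each element attends to the m summaries

rng = np.random.default_rng(0)
d, h, m, n = 8, 2, 4, 50
I1, I2 = rng.normal(size=(m, d)), rng.normal(size=(m, d))  # trainable in practice
params = [init_mab(d, h, rng) for _ in range(4)]

def encoder(X):
    """Encoder(X) = ISAB_m(ISAB_m(X)); each ISAB has its own inducing points and MABs."""
    return isab(isab(X, I1, params[0], params[1]), I2, params[2], params[3])

X = rng.normal(size=(n, d))
perm = rng.permutation(n)
print(encoder(X).shape)  # (50, 8)
print(np.allclose(encoder(X[perm]), encoder(X)[perm]))  # True: permutation equivariant
print(np.allclose(sab(X[perm], params[0]), sab(X, params[0])[perm]))  # True
```

Equivariance of the ISAB follows from its two steps: $H = \mathrm{MAB}(I, X)$ does not depend on the order of the rows of $X$, and $\mathrm{MAB}(X, H)$ processes each row of $X$ as an independent query.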
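After the encoder transforms data $X \in \mathbb{R}^{n \times d_x}$ into features $Z \in \mathbb{R}^{n \times d}$, the decoder aggregates them into a single vector (or a set of $k$ vectors), which is fed into a feed-forward network to get the final outputs. A common aggregation scheme is to simply take the average or dimension-wise maximum of the feature vectors (cf. Section 1). We instead aggregate features by applying multihead attention on a learnable set of $k$ seed vectors $S \in \mathbb{R}^{k \times d}$. We call this scheme Pooling by Multihead Attention (PMA):

$$\mathrm{PMA}_k(Z) = \mathrm{MAB}(S, \mathrm{rFF}(Z)).$$

To model interactions among the $k$ pooled outputs, the decoder can apply an SAB to them before a final rFF:

$$\mathrm{Decoder}(Z) = \mathrm{rFF}(\mathrm{SAB}(\mathrm{PMA}_k(Z))).$$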
Note that the output of $\mathrm{PMA}_k$ is a set of $k$ items. In most cases, using one seed vector ($k = 1$) and no SAB sufficed. However, when the problem of interest requires $k$ correlated outputs, the natural thing to do is to use $k$ seed vectors. An example of such a problem is amortized clustering, where the desired output is $k$ centers. In this case, the additional SAB was crucial because it allowed the network to directly take the correlation between the pooled features into account. Intuitively, feature aggregation using attention should be beneficial because the influence of each instance on the target is not necessarily equal. For example, consider a problem where the target value is the maximum value of a set of real numbers. Since the target can be recovered using only a single instance (the largest), finding and attending to that instance during aggregation will be advantageous. In the next subsection, we analyze the encoder and decoder structures more rigorously.
Since the blocks used to construct the encoder (i.e., SAB, ISAB) are permutation equivariant, the mapping of the encoder is permutation equivariant as well. Combined with the fact that the PMA in the decoder is a permutation invariant transformation, we have the following:
The Set Transformer is permutation invariant.
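The proposition can be checked numerically on a small instance. The NumPy sketch below composes an SAB encoder with the PMA-based decoder described above and verifies that permuting the input set leaves the output unchanged. It is an illustration under assumptions, not the reference implementation: weights are random, LayerNorm has no gain or bias, each rFF is one ReLU layer, and the input dimension is taken equal to $d$ so no input projection is needed.

```python
import numpy as np

# Minimal helpers (assumed: random weights, LayerNorm without gain/bias,
# one-layer ReLU rFF, no dropout, no biases in the attention projections).
def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def layer_norm(X, eps=1e-5):
    mu = X.mean(axis=1, keepdims=True)
    return (X - mu) / np.sqrt(X.var(axis=1, keepdims=True) + eps)

def init_rff(d, rng):
    return {"W1": rng.normal(scale=1.0 / np.sqrt(d), size=(d, d)), "b1": np.zeros(d)}

def rff(H, p):
    return np.maximum(H @ p["W1"] + p["b1"], 0.0)

def init_mab(d, h, rng):
    dm, s = d // h, 1.0 / np.sqrt(d)
    p = {name: rng.normal(scale=s, size=(h, d, dm)) for name in ("WQ", "WK", "WV")}
    p["WO"] = rng.normal(scale=s, size=(h * dm, d))
    p.update(init_rff(d, rng))
    return p

def mab(X, Y, p):
    d = X.shape[1]
    heads = [softmax((X @ WQ) @ (Y @ WK).T / np.sqrt(d), axis=1) @ (Y @ WV)
             for WQ, WK, WV in zip(p["WQ"], p["WK"], p["WV"])]
    H = layer_norm(X + np.concatenate(heads, axis=1) @ p["WO"])
    return layer_norm(H + rff(H, p))

def sab(X, p):
    return mab(X, X, p)

def pma(Z, S, p_mab, p_rff):
    """PMA_k(Z) = MAB(S, rFF(Z)): the k seed vectors S (k, d) attend to the n features."""
    return mab(S, rff(Z, p_rff), p_mab)

rng = np.random.default_rng(0)
d, h, k, n = 8, 2, 4, 30
enc = [init_mab(d, h, rng) for _ in range(2)]
S = rng.normal(size=(k, d))  # seed vectors, learnable in practice
p_pma, p_pma_rff = init_mab(d, h, rng), init_rff(d, rng)
p_dec_sab, p_out = init_mab(d, h, rng), init_rff(d, rng)

def set_transformer(X):
    Z = sab(sab(X, enc[0]), enc[1])             # Encoder(X) = SAB(SAB(X)): equivariant
    pooled = pma(Z, S, p_pma, p_pma_rff)        # (k, d): invariant to the row order of Z
    return rff(sab(pooled, p_dec_sab), p_out)   # Decoder(Z) = rFF(SAB(PMA_k(Z)))

X = rng.normal(size=(n, d))
out = set_transformer(X)
print(out.shape)  # (4, 8)
print(np.allclose(out, set_transformer(X[rng.permutation(n)])))  # True
```

The SAB inside the decoder operates on the $k$ seed-indexed outputs, whose order is fixed by $S$, so it does not break invariance with respect to the input set.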
Being able to approximate any function is a desirable property, especially for black-box models such as deep neural networks. Building on previous results about the universal approximation of permutation invariant functions, we prove the universality of Set Transformers:
The Set Transformer is a universal approximator of permutation invariant functions.
See Appendix A. ∎
To evaluate the Set Transformer, we apply it to a suite of tasks involving sets of data points. We repeat all experiments five times and report performance metrics evaluated on the corresponding test datasets. We compare architectures obtained by choosing whether or not to use attention in the encoder and in the decoder; several of these combinations roughly correspond to existing works as special cases. Unless specified otherwise, "simple pooling" means average pooling.
- rFF + Pooling (Zaheer et al., 2017): rFF layers in encoder and simple pooling + rFF layers in decoder.
- rFF + PMA (includes Ilse et al. (2018) as a special case): rFF layers in encoder and PMA (followed by a stack of SABs) in decoder.
- SAB (ISAB) + Pooling: Stack of SABs (ISABs) in encoder and simple pooling + rFF layers in decoder.
- Set Transformer: Stack of SABs (ISABs) in encoder and PMA (followed by a stack of SABs) in decoder.
4.1 Toy Problem: Maximum Value Regression
To demonstrate the advantage of attention-based set aggregation over simple pooling operations, we consider a toy problem: regression to the maximum value of a given set. Given a set of real numbers $\{x_1, \ldots, x_n\}$, the goal is to return $\max(x_1, \ldots, x_n)$. Given a prediction $p$, we use the mean absolute error $|p - \max(x_1, \ldots, x_n)|$ as the loss function. We constructed simple pooling architectures with three different pooling operations: max, mean, and sum. We report loss values after training in Table 2. Mean- and sum-pooling architectures result in a high mean absolute error (MAE). The model with max-pooling can predict the output perfectly by learning its encoder to be an identity function, and thus achieves the highest performance. Notably, the Set Transformer achieves performance comparable to the max-pooling model, which underlines the importance of the additional flexibility granted by attention mechanisms: it can learn to find and attend to the maximum element.
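To make the task and its loss concrete, here is a small sketch of data generation and MAE evaluation. The set-size range and the uniform value distribution are hypothetical choices, since this section does not specify them. It shows why max pooling with an identity encoder is exact, while averaging the raw elements underestimates the target on any set with more than one distinct value.

```python
import numpy as np

def sample_sets(num_sets, max_size, rng):
    """Hypothetical generator: random set sizes in [1, max_size], values uniform on [0, 100)."""
    return [rng.uniform(0.0, 100.0, size=rng.integers(1, max_size + 1)) for _ in range(num_sets)]

def mae(preds, sets):
    """Mean absolute error |p - max(x_1, ..., x_n)| over a collection of sets."""
    targets = np.array([s.max() for s in sets])
    return float(np.mean(np.abs(np.asarray(preds) - targets)))

rng = np.random.default_rng(0)
sets = sample_sets(1000, 10, rng)
max_pool_preds = [s.max() for s in sets]    # identity encoder + max pooling
mean_pool_preds = [s.mean() for s in sets]  # identity encoder + mean pooling, no learned decoder
print(mae(max_pool_preds, sets))  # 0.0
print(mae(mean_pool_preds, sets) > 0.0)  # True
```

A learned encoder and decoder can reduce the mean-pooling error, but because set sizes vary, the mean of the embeddings does not in general determine the maximum, which is consistent with the gap reported in Table 2.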
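Table 2: Mean absolute error on the max regression task.

| Architecture | MAE |
|---|---|
| rFF + Pooling (mean) | 2.133 ± 0.190 |
| rFF + Pooling (sum) | 1.902 ± 0.137 |
| rFF + Pooling (max) | 0.1355 ± 0.0074 |
| Set Transformer | 0.2085 ± 0.0127 |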
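Table 3: Accuracy on the unique character counting task.

| Architecture | Accuracy |
|---|---|
| rFF + Pooling | 0.5618 ± 0.0072 |
| rFF + PMA | 0.5428 ± 0.0076 |
| SAB + Pooling | 0.4477 ± 0.0077 |
| Set Transformer | 0.4178 ± 0.0075 |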
4.2 Counting Unique Characters
In order to test the ability of modelling interactions between objects in a set, we introduce a new task of counting unique elements in an input set. We use the Omniglot (Lake et al., 2015) dataset, which consists of 1,623 different handwritten characters from various alphabets, where each character is represented by 20 different images.
We split all characters (and corresponding images) into train, validation, and test sets and only train using images from the train character classes. We generate input sets by sampling between 6 and 10 images and we train the model to predict the number of different characters inside the set. We used a Poisson regression model to predict this number, with the rate given as the output of a neural network. We maximized the log likelihood of this model using stochastic gradient ascent.
We evaluated model performance using sets of images sampled from the test set of characters. Table 3 reports accuracy, measured as the frequency at which the mode of the Poisson distribution chosen by the network is equal to the number of characters inside the input set.
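The training objective and the accuracy metric above can be sketched as follows. The rates and counts are hypothetical network outputs for illustration, and how the network keeps its rate positive (for example via an exponential or softplus output) is an assumption not specified here. The mode of a Poisson distribution with rate $r$ is $\lfloor r \rfloor$ (when $r$ is an integer, $r - 1$ is also a mode).

```python
import math
import numpy as np

def poisson_log_likelihood(k, rate):
    """log p(k | rate) = k log(rate) - rate - log(k!) for a Poisson distribution."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)

def poisson_mode(rate):
    """Mode of Poisson(rate) is floor(rate); if rate is an integer, rate - 1 is also a mode."""
    return math.floor(rate)

def accuracy(rates, true_counts):
    """Fraction of sets whose predicted mode equals the true number of unique characters."""
    preds = np.array([poisson_mode(r) for r in rates])
    return float(np.mean(preds == np.asarray(true_counts)))

# Hypothetical network outputs (rates) for four sets and their true unique-character counts.
rates = [3.4, 5.9, 2.2, 7.1]
counts = [3, 6, 2, 6]
print(accuracy(rates, counts))  # 0.5
```

For a fixed count $k$, the log-likelihood $k \log r - r$ is maximized at $r = k$, so gradient ascent on this objective pushes the predicted rate toward the true count.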
4.3 Solving Maximum Likelihood Problems for Mixture of Gaussians
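We applied the set-input networks to the task of maximum likelihood estimation for mixtures of Gaussians (MoGs). The log-likelihood of a dataset $X = \{x_1, \ldots, x_n\}$ generated from an MoG with $k$ components is

$$\log p(X; \theta) = \sum_{i=1}^{n} \log \sum_{j=1}^{k} \pi_j \, \mathcal{N}\big(x_i; \mu_j, \mathrm{diag}(\sigma_j^2)\big).$$

The goal is to learn the optimal parameters $\theta^*(X) = \arg\max_\theta \log p(X; \theta)$. The typical approach to this problem is to run an iterative algorithm such as Expectation-Maximisation (EM) until convergence. Instead, we aim to learn a generic meta-algorithm that directly maps the input set $X$ to $\theta^*(X)$. One can also view this as amortized maximum likelihood learning. Specifically, given a dataset $X$, we train a neural network $f_\lambda$ to output parameters $f_\lambda(X) = \theta$ which maximize

$$\mathbb{E}_X\left[\sum_{i=1}^{n} \log \sum_{j=1}^{k} \pi_j(X) \, \mathcal{N}\big(x_i; \mu_j(X), \mathrm{diag}(\sigma_j^2(X))\big)\right].$$

We structured $f_\lambda$ as a set-input neural network and learned its parameters $\lambda$ using stochastic gradient ascent, where we approximate gradients using minibatches of datasets.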
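The objective above is the MoG log-likelihood evaluated at the parameters predicted by the network. A NumPy sketch with diagonal covariances, computed stably with log-sum-exp, is shown below; the toy dataset (centers, noise scale, set size) is hypothetical and is not the generative process described in Section B.3. Dividing by $n$ gives a per-data log-likelihood.

```python
import numpy as np

def logsumexp(A, axis):
    """Stable log(sum(exp(A))) along `axis`."""
    m = A.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(A - m).sum(axis=axis))

def mog_log_likelihood(X, pi, mu, sigma):
    """sum_i log sum_j pi_j N(x_i; mu_j, diag(sigma_j^2)).

    X: (n, D) data; pi: (k,) mixing weights; mu: (k, D) means; sigma: (k, D) std. devs.
    """
    diff = X[:, None, :] - mu[None, :, :]  # (n, k, D)
    # Diagonal Gaussian log-density, summed over the D dimensions -> (n, k).
    log_norm = -0.5 * np.sum((diff / sigma) ** 2 + 2.0 * np.log(sigma) + np.log(2.0 * np.pi), axis=2)
    return float(np.sum(logsumexp(np.log(pi)[None, :] + log_norm, axis=1)))

# Hypothetical toy dataset: 300 points around 4 random centers in 2D.
rng = np.random.default_rng(0)
k, n = 4, 300
centers = rng.uniform(-4.0, 4.0, size=(k, 2))
X = centers[rng.integers(k, size=n)] + 0.3 * rng.normal(size=(n, 2))
pi, sigma = np.full(k, 1.0 / k), np.full((k, 2), 0.3)
# An amortized network f(X) would output (pi, mu, sigma); training maximizes this, averaged over datasets.
print(mog_log_likelihood(X, pi, centers, sigma) / n)  # per-data log-likelihood
```

Working in log space with log-sum-exp avoids underflow when a point is far from every component, which matters once the predicted $\sigma_j$ become small.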
We tested Set Transformers along with other set-input networks on two types of datasets. We used four seed vectors for the PMA ($S \in \mathbb{R}^{4 \times d}$), the same as the number of clusters.
Synthetic 2D mixtures of Gaussians: Each dataset contains points on a 2D plane, each sampled from one of four Gaussians.
CIFAR-100 meta-clustering: Each dataset contains images sampled from four random classes in the CIFAR-100 dataset. Each image is represented by a 512-dim vector obtained from a pretrained VGG net (Simonyan & Zisserman, 2014).
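Table 4: Meta-clustering results. LL/data is the average log-likelihood per data point on the synthetic 2D datasets; ARI is the adjusted Rand index on CIFAR-100; columns marked "after 1 EM step" report the metric after a single EM update. The number in parentheses is the number of inducing points.

| Architecture | Synthetic LL/data | Synthetic LL/data (after 1 EM step) | CIFAR-100 ARI | CIFAR-100 ARI (after 1 EM step) |
|---|---|---|---|---|
| rFF + Pooling | -2.0006 ± 0.0123 | -1.6186 ± 0.0042 | 0.5593 ± 0.0149 | 0.5693 ± 0.0171 |
| SAB + Pooling | -1.6772 ± 0.0066 | -1.5070 ± 0.0115 | 0.5831 ± 0.0341 | 0.5943 ± 0.0337 |
| ISAB (16) + Pooling | -1.6955 ± 0.0730 | -1.4742 ± 0.0158 | 0.5672 ± 0.0124 | 0.5805 ± 0.0122 |
| rFF + PMA | -1.6680 ± 0.0040 | -1.5409 ± 0.0037 | 0.7612 ± 0.0237 | 0.7670 ± 0.0231 |
| Set Transformer | -1.5145 ± 0.0046 | -1.4619 ± 0.0048 | 0.9015 ± 0.0097 | 0.9024 ± 0.0097 |
| Set Transformer (16) | -1.5009 ± 0.0068 | -1.4530 ± 0.0037 | 0.9210 ± 0.0055 | 0.9223 ± 0.0056 |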
We report the performance of the oracle and of different models in Table 4. Additionally, the table contains the scores attained by all models after a single EM update. Overall, the Set Transformer found accurate parameters and even outperformed the oracles after a single EM update. This can be explained by the relatively small size of the input sets, which leads to some clusters having fewer than 10 points. In this regime, sample statistics can differ from population statistics, which limits the performance of the oracle, but the Set Transformer can adapt accordingly. Notably, the Set Transformer with only 16 inducing points showed the best performance, even outperforming the full Set Transformer. We believe this is due to knowledge transfer and regularization via the inducing points, which help the network learn global structure. Our results also imply that the improvements from using the PMA are more significant than those from using the SAB, supporting our claim of the importance of attention-based decoders. We provide detailed generative processes, network architectures, and training schemes, along with additional experiments with various numbers of inducing points, in Section B.3.
4.4 Meta Set Anomaly Detection
Table 5: Meta set anomaly detection results on CelebA.

| Architecture | Test AUROC | Test AUPR |
|---|---|---|
| rFF + Pooling | 0.5643 ± 0.0139 | 0.4126 ± 0.0108 |
| SAB + Pooling | 0.5757 ± 0.0143 | 0.4189 ± 0.0167 |
| rFF + PMA | 0.5756 ± 0.0130 | 0.4227 ± 0.0127 |
| Set Transformer | 0.5941 ± 0.0170 | 0.4386 ± 0.0089 |
We evaluate our methods on the task of meta-anomaly detection within a set using the CelebA dataset. The dataset consists of 202,599 images with a total of 40 attributes. We randomly sample 1,000 sets of images. For every set, we select two attributes at random and construct the set by selecting seven images containing both attributes and one image with neither. The goal of this task is to find the image that does not belong to the set. We give a detailed description of the experimental setup in Section B.4. Table 5 contains empirical results, which show that Set Transformers outperformed all other methods by a significant margin.
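The set construction described above can be sketched as follows. The attribute matrix here is random and only stands in for the CelebA labels; resampling the attribute pair when too few images qualify, and shuffling so the anomaly is not always in the last position, are assumptions made for this sketch rather than details taken from Section B.4.

```python
import numpy as np

def make_anomaly_set(attrs, rng, n_normal=7):
    """Pick two attributes at random, take `n_normal` images that have both and one image
    that has neither, then shuffle.

    attrs: (num_images, num_attributes) 0/1 matrix (hypothetical stand-in for CelebA labels).
    Returns (image indices, position of the anomalous image, (attribute a, attribute b)).
    """
    while True:
        a, b = rng.choice(attrs.shape[1], size=2, replace=False)
        both = np.flatnonzero((attrs[:, a] == 1) & (attrs[:, b] == 1))
        neither = np.flatnonzero((attrs[:, a] == 0) & (attrs[:, b] == 0))
        if len(both) >= n_normal and len(neither) >= 1:
            break  # otherwise resample the attribute pair (an assumption)
    idx = np.append(rng.choice(both, size=n_normal, replace=False), rng.choice(neither))
    order = rng.permutation(n_normal + 1)  # shuffle so the anomaly is not always last
    anomaly_pos = int(np.flatnonzero(order == n_normal)[0])
    return idx[order], anomaly_pos, (int(a), int(b))

rng = np.random.default_rng(0)
attrs = (rng.random((500, 40)) < 0.5).astype(int)  # 40 attributes as in CelebA; values are random
idx, pos, (a, b) = make_anomaly_set(attrs, rng)
print(len(idx), attrs[idx[pos], [a, b]])  # 8 [0 0]
```

A model for this task receives the eight images as a set and must score each element, so it needs a permutation-equivariant output rather than a single pooled vector.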
4.5 Point Cloud Classification
We evaluated Set Transformers on a classification task using the ModelNet40 (Chang et al., 2015) dataset, containing 40 categories of three-dimensional objects. Each object is represented as a point cloud, which we treat as a set of $n$ elements in $\mathbb{R}^3$. Table 6 contains experimental results on point clouds with $n$ points each (the point-cloud dataset used in this experiment was obtained directly from the authors of Zaheer et al. (2017)). In this setting, SABs turned out to be prohibitively expensive due to their $O(n^2)$ time complexity. Additional results with a different number of points and experiment details are available in Section B.5. Note that ISAB (16) + Pooling outperformed the Set Transformer (ISAB (16) + PMA (1)) by a large margin. Our interpretation is that the class of a point cloud object can be efficiently represented by simple aggregation of point features, and that the PMA suffered from an optimization issue in this setting. We would like to point out that PMA outperformed simple pooling in all other experiments.
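Table 6: Test accuracy on ModelNet40 point cloud classification.

| Architecture | Accuracy |
|---|---|
| rFF + Pooling | 0.8551 ± 0.0142 |
| rFF + PMA (1) | 0.8534 ± 0.0152 |
| ISAB (16) + Pooling | 0.8915 ± 0.0144 |
| Set Transformer (16) | 0.8662 ± 0.0149 |
| rFF + Pooling (Zaheer et al., 2017) | 0.83 ± 0.01 |
| rFF + Pooling + tricks (Zaheer et al., 2017) | 0.87 ± 0.01 |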
5 Related Works
Pooling architectures for permutation invariant mappings Pooling architectures for sets have been used in various problems such as 3D shape recognition (Shi et al., 2015; Su et al., 2015), discovering causality (Lopez-Paz et al., 2016), learning the statistics of a set (Edwards & Storkey, 2017), few-shot image classification (Snell et al., 2017), and conditional regression and classification (Garnelo et al., 2018). Zaheer et al. (2017) discuss the structure in general and provide a partial proof of the universality of the pooling architecture.
Attention-based approaches for sets Vinyals et al. (2016) propose an architecture to map sets into sequences, where elements in a set are pooled by a weighted average with weights computed by an attention mechanism. Several recent works have highlighted the competency of attention mechanisms in modeling sets. Yang et al. (2018) propose AttSets for multi-view 3D reconstruction, where attention is applied to the encoded features of elements in sets before pooling. Similarly, Ilse et al. (2018) use attention-based pooling for multiple instance learning. Although not permutation invariant, the model of Mishra et al. (2018) has attention as one of its core components to meta-learn to solve various tasks using sequences of inputs.
Modeling interactions between elements in sets An important reason to use the Transformer is to explicitly model higher-order interactions among the elements in a set. Santoro et al. (2017) propose the relational network, a simple architecture that sum-pools all pairwise interactions of elements in a given set, but not higher-order interactions. Similarly to our work, Ma et al. (2018) use the Transformer to model interactions between the objects in a video. They use mean-pooling to obtain aggregated features, which they feed into an LSTM.
Inducing point methods The idea of letting trainable vectors directly interact with data points is loosely based on the inducing point methods used in sparse Gaussian processes (Quiñonero-Candela & Rasmussen, 2005) and the Nyström method for matrix decomposition (Fowlkes et al., 2004). Trainable inducing points can also be seen as independent memory cells accessed with an attention mechanism. The Differentiable Neural Dictionary (Pritzel et al., 2017) stores previous experience as key-value pairs and uses this to process queries. One can view the ISAB as the inversion of this idea, where queries are stored and the input features are used as key-value pairs.
In this paper, we introduced the Set Transformer, an attention-based set-input neural network architecture. Our proposed method uses attention mechanisms for both encoding and aggregating features, and we have empirically validated that both of them are necessary for modelling complicated interactions among elements of a set. We also proposed an inducing point method for self-attention, which makes our approach scalable to large sets. We also showed useful theoretical properties of our model, including the fact that it is a universal approximator for permutation invariant functions. To the best of our knowledge, no previous work has successfully trained a neural network to perform amortized clustering in a single forward pass. An interesting topic for future work would be to apply Set Transformers to meta-learning problems other than meta-clustering. In particular, using Set Transformers to meta-learn posterior inference in Bayesian models seems like a promising line of research. Another exciting extension of our work would be to model the uncertainty in set functions by injecting noise variables into Set Transformers in a principled way.
- Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv:1607.06450, 2016.
- Chang et al. (2015) Angel X. Chang, Thomas Funkhouser, Leonidas Guibas, Pat Hanrahan, Qixing Huang, Zimo Li, Silvio Savarese, Manolis Savva, Shuran Song, Hao Su, Jianxiong Xiao, Li Yi, and Fisher Yu. ShapeNet: an information-rich 3D model repository. arXiv:1512.03012, 2015.
- Charles et al. (2017) R Qi Charles, Hao Su, Mo Kaichun, and Leonidas J Guibas. Pointnet: Deep learning on point sets for 3d classification and segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017.
- Dietterich et al. (1997) Thomas G. Dietterich, Richard H. Lathrop, and Tomás Lozano-Pérez. Solving the multiple instance problem with axis-parallel rectangles. Artificial Intelligence, 89(1-2):31–71, 1997.
- Edwards & Storkey (2017) Harrison Edwards and Amos Storkey. Towards a neural statistician. In Proceedings of the International Conference on Learning Representations (ICLR), 2017.
- Finn et al. (2017) Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the International Conference on Machine Learning (ICML), 2017.
- Fowlkes et al. (2004) Charless Fowlkes, Serge Belongie, Fan Chung, and Jitendra Malik. Spectral grouping using the Nyström method. IEEE Transactions on Pattern Analysis and Machine Intelligence, 25(2):215–225, February 2004.
- Garnelo et al. (2018) Marta Garnelo, Dan Rosenbaum, Chris J. Maddison, Tiago Ramalho, David Saxton, Murray Shanahan, Yee Whye Teh, Danilo J. Rezende, and S. M. Ali Eslami. Conditional neural processes. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
- Graves et al. (2013) Alex Graves, Abdel-rahman Mohamed, and Geoffrey Hinton. Speech recognition with deep recurrent neural networks. In Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP), 2013.
- Ilse et al. (2018) Maximilian Ilse, Jakub M. Tomczak, and Max Welling. Attention-based deep multiple instance learning. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
- Kingma & Ba (2015) Diederik P. Kingma and Jimmy Ba. Adam: a method for stochastic optimization. In Proceedings of the International Conference on Learning Representations (ICLR), 2015.
- Krizhevsky et al. (2012) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems (NIPS), 2012.
- Lake et al. (2015) Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.
- Lee & Choi (2018) Yoonho Lee and Seungjin Choi. Gradient-based meta-learning with learned layerwise metric and subspace. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
- Lopez-Paz et al. (2016) David Lopez-Paz, Robert Nishihara, Soumith Chintala, Bernhard Schölkopf, and Léon Bottou. Discovering causal signals in images. arXiv:1605.08179, 2016.
- Ma et al. (2018) Chih-Yao Ma, Asim Kadav, Iain Melvin, Zsolt Kira, Ghassan AlRegib, and Hans Peter Graf. Attend and interact: higher-order object interactions for video understanding. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
- Maron & Lozano-Pérez (1998) Oded Maron and Tomás Lozano-Pérez. A framework for multiple-instance learning. In Advances in Neural Information Processing Systems (NIPS), 1998.
- Mishra et al. (2018) Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive meta-learner. In Proceedings of the International Conference on Machine Learning (ICML), 2018.
- Muandet et al. (2012) Krikamol Muandet, Kenji Fukumizu, Francesco Dinuzzo, and Bernhard Schölkopf. Learning from distributions via support measure machines. In Advances in Neural Information Processing Systems (NIPS), 2012.
- Pritzel et al. (2017) Alexander Pritzel, Benigno Uria, Sriram Srinivasan, Adria Puigdomenech, Oriol Vinyals, Demis Hassabis, Daan Wierstra, and Charles Blundell. Neural episodic control. arXiv:1703.01988, 2017.
- Quiñonero-Candela & Rasmussen (2005) Joaquin Quiñonero-Candela and Carl Edward Rasmussen. A unifying view of sparse approximate Gaussian process regression. Journal of Machine Learning Research, 6:1939–1959, 2005.
- Santoro et al. (2017) Adam Santoro, David Raposo, David G. T. Barret, Mateusz Malinowski, Razvan Pascanu, and Peter Battaglia. A simple neural network module for relational reasoning. In Advances in Neural Information Processing Systems (NIPS), 2017.
- Schmidhuber (1987) Jürgen Schmidhuber. Evolutionary Principles in Self-Referential Learning. PhD thesis, Technical University of Munich, 1987.
- Shi et al. (2015) Baoguang Shi, Song Bai, Zhichao Zhou, and Xiang Bai. DeepPano: deep panoramic representation for 3-D shape recognition. IEEE Signal Processing Letters, 22(12):2339–2343, 2015.
- Simonyan & Zisserman (2014) Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv:1409.1556, 2014.
- Snell et al. (2017) Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems (NIPS), 2017.
- Su et al. (2015) Hang Su, Subhransu Maji, Evangelos Kalogerakis, and Erik Learned-Miller. Multi-view convolutional neural networks for 3D shape recognition. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), 2015.
- Thrun & Pratt (1998) Sebastian Thrun and Lorien Pratt. Learning to Learn. Kluwer Academic Publishers, 1998.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems (NIPS), 2017.
- Vinyals et al. (2016) Oriol Vinyals, Samy Bengio, and Manjunath Kudlur. Order matters: sequence to sequence for sets. In Proceedings of the International Conference on Learning Representations (ICLR), 2016.
- Wu et al. (2015) Zhirong Wu, Shuran Song, Aditya Khosla, Fisher Yu, Linguang Zhang, Xiaoou Tang, and Jianxiong Xiao. 3D ShapeNets: a deep representation for volumetric shapes. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2015.
- Yang et al. (2018) Bo Yang, Sen Wang, Andrew Markham, and Niki Trigoni. Attentional aggregation of deep feature sets for multi-view 3D reconstruction. arXiv:1808.00758, 2018.
- Zaheer et al. (2017) Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Ruslan R. Salakhutdinov, and Alexander J. Smola. Deep sets. In Advances in Neural Information Processing Systems (NIPS), 2017.
Appendix A Proofs
Lemma 1. The mean operator $\mathrm{mean}(\{x_1, \dots, x_n\}) = \frac{1}{n}\sum_{i=1}^{n} x_i$ is a special case of dot-product attention with softmax.
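Proof. Let $s = \mathbf{0} \in \mathbb{R}^{1 \times d}$ and $X \in \mathbb{R}^{n \times d}$ with rows $x_1, \dots, x_n$. Then
$$\mathrm{Att}(s, X, X; \mathrm{softmax}) = \mathrm{softmax}\!\left(\frac{s X^\top}{\sqrt{d}}\right) X = \frac{1}{n}\sum_{i=1}^{n} x_i,$$
since every score $s x_i^\top / \sqrt{d}$ is zero and the softmax of $n$ equal scores assigns weight $1/n$ to each element. ∎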
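The following minimal pure-Python sketch (not the authors' code; function names are illustrative) checks this construction numerically: with a zero query every score is zero, the softmax weights are uniform, and attention returns the column-wise mean of $X$.

```python
import math


def softmax(scores):
    # Numerically stable softmax over a list of scores.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(q, keys, values):
    # Scaled dot-product attention for a single query vector q.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


X = [[1.0, 2.0], [3.0, -1.0], [5.0, 0.5]]
s = [0.0, 0.0]  # zero query: every score is 0, so the softmax weights are all 1/n
att = dot_product_attention(s, X, X)
mean = [sum(col) / len(X) for col in zip(*X)]
```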
Lemma 2. The decoder of a Set Transformer, given enough nodes, can express any element-wise function of the form $\left(\frac{1}{n}\sum_{i=1}^{n} z_i^p\right)^{1/p}$.
Proof. We first note that we can view the decoder as the composition of functions
$$\mathrm{Decoder}(Z) = \mathrm{rFF}(H), \tag{16}$$
$$H = \mathrm{rFF}\big(\mathrm{MAB}(S, \mathrm{rFF}(Z))\big). \tag{17}$$
We focus on $H$ in equation (17). Since feed-forward networks are universal function approximators in the limit of infinitely many nodes, let the feed-forward layers in front of and behind the MAB encode the element-wise functions $z \mapsto z^p$ and $z \mapsto z^{1/p}$, respectively. We let $h = d$, so the number of heads equals the dimensionality of the inputs and each head is one-dimensional. Let the projection matrices in multi-head attention ($W_j^Q$, $W_j^K$, $W_j^V$) represent projections onto the $j$th dimension, let the output matrix ($W^O$) be the identity matrix, and set the seed $S$ to zero. Since the mean operator is a special case of dot-product attention, by simple composition we see that an MAB can express any dimension-wise function of the form
$$M_p(z_1, \dots, z_n) = \left(\frac{1}{n}\sum_{i=1}^{n} z_i^p\right)^{1/p}. \tag{18}$$
∎
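A scalar sketch of this construction for one head, assuming positive inputs so that $z^{1/p}$ is real. The learned feed-forward layers are replaced by the exact maps they are assumed to approximate, and the function names are hypothetical: the front map raises each value to the power $p$, uniform attention (the zero-query mean) pools them, and the back map takes the $p$th root.

```python
import math


def uniform_attention_pool(values):
    # With a zero query, softmax weights are uniform (1/n), so attention is the mean.
    return sum(values) / len(values)


def power_mean(zs, p):
    # Front feed-forward map: z -> z**p, applied element-wise.
    pre = [z ** p for z in zs]
    # Attention with uniform weights pools the transformed values.
    pooled = uniform_attention_pool(pre)
    # Back feed-forward map: z -> z**(1/p).
    return pooled ** (1.0 / p)


zs = [0.5, 2.0, 3.0, 1.0]  # positive inputs (assumption)
m1 = power_mean(zs, 1)     # arithmetic mean
m2 = power_mean(zs, 2)     # root mean square
m50 = power_mean(zs, 50)   # already close to max(zs)
```

For positive inputs, $M_1$ is the arithmetic mean and $M_p$ approaches $\max_i z_i$ as $p$ grows, which illustrates how such a decoder could approximate max pooling.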
Lemma 3. A PMA, given enough nodes, can express sum pooling $\left(\sum_{i=1}^{n} z_i\right)$.
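Proof. We prove this by construction. Set the seed $s$ to a zero vector and let $\omega(\cdot) = 1 + f(\cdot)$, where $f$ is any activation function such that $f(0) = 0$. The identity, tanh, and ReLU functions are suitable choices for $f$. Every attention score is then zero, so every weight equals $\omega(0) = 1$, and the output of the multihead attention is simply the sum of the values, which is $\sum_{i=1}^{n} z_i$ in this case. ∎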
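The PMA construction can be checked numerically with a single head and identity projections (both simplifying assumptions of this sketch, which is not the authors' code): the zero seed makes every score zero, so each weight is $1 + f(0) = 1$ and the output is the column-wise sum of the values.

```python
import math


def relu(x):
    return max(0.0, x)


def attention_with_activation(q, keys, values, f):
    # Attention whose weights are omega(score) = 1 + f(score) instead of a softmax.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    weights = [1.0 + f(s) for s in scores]
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


Z = [[1.0, -2.0], [0.5, 4.0], [2.5, 1.0]]
seed = [0.0, 0.0]  # zero seed: all scores are 0, so every weight is 1 + f(0) = 1
out_relu = attention_with_activation(seed, Z, Z, relu)
out_tanh = attention_with_activation(seed, Z, Z, math.tanh)
column_sums = [sum(col) for col in zip(*Z)]
```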
We additionally have the following universality theorem for pooling architectures:
Theorem 1. Models of the form $\mathrm{rFF}(\mathrm{sum}(\mathrm{rFF}(\cdot)))$ are universal function approximators in the space of permutation invariant functions.
Proof. See Appendix A of Zaheer et al. (2017). ∎
Proposition 1. The Set Transformer is a universal function approximator in the space of permutation invariant functions.
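Proof. By setting the matrix $W^O$ to a zero matrix in every SAB and ISAB, we can ignore all pairwise interaction terms in the encoder. Therefore, $\mathrm{encoder}(X)$ can express any instance-wise feed-forward network ($Z = \mathrm{rFF}(X)$). By the lemma above on PMA sum pooling, the decoder can in turn express any function of the form $\mathrm{rFF}(\mathrm{sum}(Z))$, so the full model can express $\mathrm{rFF}(\mathrm{sum}(\mathrm{rFF}(X)))$. Directly invoking Theorem 1 concludes the proof. ∎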
While this proof required us to ignore the pairwise interaction terms inside the SABs and ISABs to prove that Set Transformers are universal function approximators, our experiments indicated that self-attention in the encoder was crucial for good performance.
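Theorem 1 concerns models of the form $\mathrm{rFF}(\mathrm{sum}(\mathrm{rFF}(\cdot)))$. The toy sketch below uses hypothetical fixed maps `phi` and `rho` in place of learned feed-forward networks, so it only illustrates why this class is permutation invariant (not why it is universal): sum pooling discards the order of the encoded elements, so every permutation of the input yields the same output up to floating-point rounding.

```python
import itertools
import math


def phi(x):
    # Hypothetical instance-wise encoder standing in for the first rFF.
    return [x, x * x, math.tanh(x)]


def rho(h):
    # Hypothetical decoder standing in for the rFF applied after pooling.
    return 2.0 * h[0] - h[1] + 3.0 * math.tanh(h[2])


def sum_pool_model(xs):
    # rFF(sum(rFF(x_i))): encode each element, sum-pool, then decode.
    encoded = [phi(x) for x in xs]
    pooled = [sum(column) for column in zip(*encoded)]
    return rho(pooled)


xs = [0.3, -1.2, 2.0, 0.7]
outputs = [sum_pool_model(list(perm)) for perm in itertools.permutations(xs)]
```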
Appendix B Experiment Details
In all implementations, we omit the feed-forward layer at the beginning of the decoder ($\mathrm{rFF}(Z)$) because the previous block already ends with a feed-forward layer. All MABs (inside SAB, ISAB, and PMA) use fully-connected layers with ReLU activations for their rFF layers.
In the architecture descriptions, $\mathrm{FC}(d, f)$ denotes a fully-connected layer with $d$ units and activation function $f$; $\mathrm{SAB}(d, h)$ denotes an SAB with $d$ units and $h$ heads; $\mathrm{ISAB}_m(d, h)$ denotes an ISAB with $d$ units, $h$ heads, and $m$ inducing points; and $\mathrm{PMA}_k(d, h)$ denotes a PMA with $d$ units, $h$ heads, and $k$ seed vectors.
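To make the block notation concrete, here is a heavily simplified pure-Python sketch of how the blocks compose, following the standard Set Transformer definitions $\mathrm{SAB}(X) = \mathrm{MAB}(X, X)$, $\mathrm{ISAB}_m(X) = \mathrm{MAB}(X, \mathrm{MAB}(I, X))$ with inducing points $I$, and $\mathrm{PMA}_k(Z) = \mathrm{MAB}(S, \mathrm{rFF}(Z))$ with seed vectors $S$. It is an illustration, not the configuration used in the experiments: it assumes a single head, no learned projections, no LayerNorm, a parameter-free ReLU in place of rFF, and fixed random inducing points and seeds. The comments note the size of each attention score matrix, which is where ISAB reduces $n^2$ score entries to $2nm$.

```python
import math
import random


def attention(Q, K, V):
    # Single-head scaled dot-product attention with softmax; inputs are lists of rows.
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        top = max(scores)
        w = [math.exp(s - top) for s in scores]  # stable softmax numerators
        total = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / total for j in range(len(V[0]))])
    return out


def rff(X):
    # Stand-in row-wise feed-forward layer: element-wise ReLU with no learned weights.
    return [[max(0.0, x) for x in row] for row in X]


def mab(X, Y):
    # Simplified MAB: H = X + Att(X, Y, Y), output H + rFF(H).
    A = attention(X, Y, Y)
    H = [[x + a for x, a in zip(xr, ar)] for xr, ar in zip(X, A)]
    return [[h + r for h, r in zip(hr, rr)] for hr, rr in zip(H, rff(H))]


def sab(X):
    return mab(X, X)  # n x n attention scores


def isab(X, inducing):
    H = mab(inducing, X)  # m x n scores: inducing points attend to the set
    return mab(X, H)      # n x m scores: the set attends to the m summaries


def pma(Z, seeds):
    # The leading rFF of PMA is dropped, as in the implementations described above.
    return mab(seeds, Z)  # k x n scores


rng = random.Random(0)
d, n, m, k = 4, 6, 2, 1
X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
inducing = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
seeds = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]


def set_transformer(X):
    Z = isab(isab(X, inducing), inducing)  # encoder
    return sab(pma(Z, seeds))              # decoder, without the final rFF


out = set_transformer(X)
out_reversed = set_transformer(X[::-1])
```

Reversing the input set leaves the output unchanged up to rounding, since ISAB is permutation equivariant and PMA pools over the whole set.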
B.1 Max Regression
Given a set of real numbers $\{x_1, \dots, x_n\}$, the goal of this task is to return the maximum value in the set, $\max(x_1, \dots, x_n)$. We construct training data as follows. We first sample a set size $n$ uniformly from a fixed range of integers, and then sample $n$ real numbers independently from a fixed interval. Given the network's prediction $p$, we compute the absolute error $|p - \max(x_1, \dots, x_n)|$ against the true maximum and average it to obtain the mean absolute error. We do not use an explicit train/test split, since a new set is sampled at each training step.
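A minimal sketch of the data generation and evaluation described above. The exact set-size range and sampling interval are not preserved in this text, so the constants below (sizes 1 to 10, values on [0, 100]), the uniform value distribution, and the batch size of 32 are hypothetical placeholders; the mean-of-set baseline only exercises the error computation.

```python
import random

# Hypothetical placeholders: the exact ranges are not preserved in the text.
MIN_SIZE, MAX_SIZE = 1, 10
LOW, HIGH = 0.0, 100.0


def sample_max_regression_example(rng):
    # Sample a set size n, then n independent values (assumed uniform); target is the max.
    n = rng.randint(MIN_SIZE, MAX_SIZE)
    xs = [rng.uniform(LOW, HIGH) for _ in range(n)]
    return xs, max(xs)


def mean_absolute_error(predictions, targets):
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


rng = random.Random(0)
# A fresh batch is sampled at every step, so no fixed train/test split is needed.
batch = [sample_max_regression_example(rng) for _ in range(32)]
# Baseline that predicts the set mean instead of the maximum, to exercise the metric.
baseline_predictions = [sum(xs) / len(xs) for xs, _ in batch]
targets = [target for _, target in batch]
mae = mean_absolute_error(baseline_predictions, targets)
```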
B.2 Counting Unique Characters
| Architecture | Accuracy |
|---|---|
| rFF + Pooling | 0.4366 ± 0.0071 |
| rFF + PMA | 0.4617 ± 0.0073 |
| SAB + Pooling | 0.5659 ± 0.0067 |
| Set Transformers (SAB + PMA (1)) | 0.6037 ± 0.0072 |
| Set Transformers (SAB + PMA (2)) | 0.5806 ± 0.0075 |
| Set Transformers (SAB + PMA (4)) | 0.5945 ± 0.0072 |
| Set Transformers (SAB + PMA (8)) | 0.6001 ± 0.0078 |
The task generation procedure is as follows. We first sample a set size $n$ uniformly from a fixed range of integers. We then sample the number of distinct characters $c$ uniformly, subject to $c \le n$. We sample $c$ characters from the training set of characters and randomly sample instances of each one, so that the total number of instances is $n$ and each of the $c$ characters has at least one instance in the resulting set.
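The task generation can be sketched as follows. Character identities stand in for character images, the pool of 50 characters and the set-size range 6 to 10 are hypothetical, and drawing $c$ uniformly from $\{1, \dots, n\}$ is an assumption consistent with the requirement that every sampled character appears at least once.

```python
import random


def sample_counting_task(rng, characters, min_size, max_size):
    # Sample the set size n.
    n = rng.randint(min_size, max_size)
    # Assumption: c is uniform on {1, ..., n}, so each character can appear at least once.
    c = rng.randint(1, n)
    chosen = rng.sample(characters, c)
    # One guaranteed instance per chosen character, then n - c extra instances at random.
    instances = list(chosen) + [rng.choice(chosen) for _ in range(n - c)]
    rng.shuffle(instances)
    return instances, c


# Hypothetical pool of character identities standing in for character images.
characters = [f"char_{i}" for i in range(50)]
rng = random.Random(1)
tasks = [sample_counting_task(rng, characters, 6, 10) for _ in range(200)]
```

The target for each set is `c`, the number of distinct identities in `instances`.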
We show the detailed architectures used for the experiments in Table 9. For both architectures, the resulting one-dimensional output is passed through an activation that guarantees a positive value, producing the rate parameter of a Poisson distribution; this ensures that the rate is always positive.
As previously mentioned, we optimize the Poisson log-likelihood of the true count under the predicted rate (equivalently, we minimize its negative). We chose this objective over mean squared error or mean absolute error because it is a natural likelihood for matching a real-valued prediction to an integer target. Early experiments showed that directly optimizing mean absolute error gave roughly the same result as optimizing the Poisson likelihood and measuring mean absolute error afterwards. We train using the Adam optimizer with a constant learning rate and a fixed batch size.
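The Poisson objective can be written out directly. This sketch assumes a softplus output activation (one common way to keep the rate positive; the text above does not preserve which activation was used) and a hypothetical rounding rule for turning the rate into an integer prediction. For true count $k$ and rate $\lambda$, the loss is the negative log-likelihood $\lambda - k\log\lambda + \log k!$.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)); always positive.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def poisson_nll(raw_output, count):
    # -log Poisson(count | rate) = rate - count * log(rate) + log(count!)
    rate = softplus(raw_output)
    return rate - count * math.log(rate) + math.lgamma(count + 1)


def predicted_count(raw_output):
    # Hypothetical decoding rule: round the predicted rate to the nearest integer.
    return round(softplus(raw_output))


raw_at_3 = math.log(math.expm1(3.0))  # raw output whose softplus is exactly the rate 3
```

Because the negative log-likelihood is minimized when the rate equals the true count, nudging `raw_at_3` in either direction increases the loss for a target of 3.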
B.3 Solving maximum likelihood problems for mixtures of Gaussians
B.3.1 Details for the 2D synthetic mixtures of Gaussians experiment
We generated the datasets according to the following generative process:
- Generate the number of data points, $n$.
- Generate cluster labels.
- Generate data from spherical Gaussians.
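One way to implement the three steps above is sketched below. The symmetric Dirichlet prior on the mixing weights, the four clusters, the set-size range, and the two standard deviations are all hypothetical placeholders, since the original parameters are not preserved here. Each training step draws a batch of 10 such datasets.

```python
import random


def sample_mog_dataset(rng, n_min=100, n_max=500, num_clusters=4,
                       mean_std=4.0, noise_std=0.3, dim=2):
    # All default hyperparameters are hypothetical placeholders.
    # Step 1: sample the number of data points.
    n = rng.randint(n_min, n_max)
    # Step 2: draw mixing weights from a symmetric Dirichlet(1, ..., 1) by normalizing
    # Gamma(1, 1) samples, then draw one cluster label per point.
    raw = [rng.gammavariate(1.0, 1.0) for _ in range(num_clusters)]
    total = sum(raw)
    weights = [r / total for r in raw]
    labels = rng.choices(range(num_clusters), weights=weights, k=n)
    # Step 3: draw cluster means, then each point from a spherical Gaussian
    # centered at its cluster mean.
    means = [[rng.gauss(0.0, mean_std) for _ in range(dim)] for _ in range(num_clusters)]
    points = [[rng.gauss(mu, noise_std) for mu in means[z]] for z in labels]
    return points, labels, means


rng = random.Random(0)
# One training step draws a batch of 10 datasets.
batch = [sample_mog_dataset(rng) for _ in range(10)]
```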
Table 10 summarizes the architectures used for the experiments. For all architectures, at each training step we generate 10 random datasets according to the above generative process and update the parameters with the Adam optimizer. We train all algorithms for a fixed number of steps, decaying the learning rate once partway through training. Table 11 summarizes the detailed results with various numbers of inducing points in the ISAB. Figure 3 shows the actual clustering results based on the predicted parameters.
| Architecture | LL0/data | LL1/data |
|---|---|---|
| rFF + Pooling | -2.0006 ± 0.0123 | -1.6186 ± 0.0042 |
| SAB + Pooling | -1.6772 ± 0.0066 | -1.5070 ± 0.0115 |
| ISAB (16) + Pooling | -1.6955 ± 0.0730 | -1.4742 ± 0.0158 |
| ISAB (32) + Pooling | -1.6353 ± 0.0182 | -1.4681 ± 0.0038 |
| ISAB (64) + Pooling | -1.6349 ± 0.0429 | -1.4664 ± 0.0080 |
| rFF + PMA | -1.6680 ± 0.0040 | -1.5409 ± 0.0037 |
| Set Transformer | -1.5145 ± 0.0046 | -1.4619 ± 0.0048 |
| Set Transformer (16) | -1.5009 ± 0.0068 | -1.4530 ± 0.0037 |
| Set Transformer (32) | -1.4963 ± 0.0064 | -1.4524 ± 0.0044 |
| Set Transformer (64) | -1.5042 ± 0.0158 | -1.4535 ± 0.0053 |
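Predicted mixture parameters can be scored by the average log-likelihood per data point, which the negative values in the table above appear to measure. Below is a minimal sketch for isotropic components, using log-sum-exp for numerical stability; the parameter layout (lists of mixing weights, means, and per-component variances) and the toy points are assumptions of this example.

```python
import math


def mog_log_likelihood_per_point(points, weights, means, variances):
    # Average of log sum_j pi_j N(x | mu_j, var_j * I) over the data points.
    dim = len(points[0])
    total = 0.0
    for x in points:
        log_terms = []
        for pi, mu, var in zip(weights, means, variances):
            sq_dist = sum((a - b) ** 2 for a, b in zip(x, mu))
            log_terms.append(math.log(pi) - 0.5 * dim * math.log(2.0 * math.pi * var)
                             - 0.5 * sq_dist / var)
        top = max(log_terms)
        total += top + math.log(sum(math.exp(t - top) for t in log_terms))  # log-sum-exp
    return total / len(points)


points = [[0.1, -0.2], [3.9, 4.1], [4.2, 3.8]]
ll_good = mog_log_likelihood_per_point(points, [0.4, 0.6], [[0.0, 0.0], [4.0, 4.0]], [0.1, 0.1])
ll_bad = mog_log_likelihood_per_point(points, [0.5, 0.5], [[2.0, 0.0], [0.0, 2.0]], [0.1, 0.1])
```

Parameters that place components near the data (`ll_good`) score higher than misplaced ones (`ll_bad`).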
B.3.2 Details for CIFAR-100 meta clustering experiment
We pretrained a VGG network (Simonyan & Zisserman, 2014) on CIFAR-100, obtaining a test accuracy of 68.54%. We then extracted feature vectors for the 50k CIFAR-100 training images from the 512-dimensional hidden layer of the VGG network (the layer just before the last layer). Given these feature vectors, the generative process for datasets is as follows:
- Generate the number of data points, $n$.
- Uniformly sample four classes among the 100 classes.
- Uniformly sample $n$ data points among the four sampled classes.
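The CIFAR-100 dataset sampler can be sketched the same way. Here `toy_features` is a hypothetical stand-in for the extracted 512-dimensional VGG features (100 classes with 30 two-dimensional vectors each), the range of $n$ is a placeholder, and sampling without replacement is an assumption.

```python
import random


def sample_meta_clustering_dataset(rng, features_by_class, n_min, n_max, num_classes=4):
    # Step 1: sample the number of data points.
    n = rng.randint(n_min, n_max)
    # Step 2: choose four distinct classes uniformly at random.
    classes = rng.sample(sorted(features_by_class), num_classes)
    # Step 3: draw n points uniformly from the pooled examples of those classes.
    # Assumption: sampling without replacement, so the pool must hold at least n examples.
    pool = [(c, feat) for c in classes for feat in features_by_class[c]]
    chosen = rng.sample(pool, n)
    features = [feat for _, feat in chosen]
    labels = [c for c, _ in chosen]
    return features, labels, classes


# Hypothetical stand-in: 100 classes with 30 two-dimensional "features" each.
toy_features = {c: [[float(c), float(i)] for i in range(30)] for c in range(100)}
rng = random.Random(0)
features, labels, classes = sample_meta_clustering_dataset(rng, toy_features, 20, 60)
```

The class labels are returned only for evaluation (for example, ARI against the predicted clustering); the model sees `features` alone.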
Table 12 summarizes the architectures used for the experiments. For all architectures, at each training step we generate 10 random datasets according to the above generative process and update the parameters with the Adam optimizer. We train all algorithms for a fixed number of steps, decaying the learning rate once partway through training. Table 13 summarizes the detailed results with various numbers of inducing points in the ISAB.
| Architecture | ARI0 | ARI1 |
|---|---|---|
| rFF + Pooling | 0.5593 ± 0.0149 | 0.5693 ± 0.0171 |
| SAB + Pooling | 0.5831 ± 0.0341 | 0.5943 ± 0.0337 |
| ISAB (16) + Pooling | 0.5672 ± 0.0124 | 0.5805 ± 0.0122 |
| ISAB (32) + Pooling | 0.5587 ± 0.0104 | 0.5700 ± 0.0134 |
| ISAB (64) + Pooling | 0.5586 ± 0.0205 | 0.5708 ± 0.0183 |
| rFF + PMA | 0.7612 ± 0.0237 | 0.7670 ± 0.0231 |
| Set Transformer | 0.9015 ± 0.0097 | 0.9024 ± 0.0097 |
| Set Transformer (16) | 0.9210 ± 0.0055 | 0.9223 ± 0.0056 |
| Set Transformer (32) | 0.9103 ± 0.0061 | 0.9119 ± 0.0052 |
| Set Transformer (64) | 0.9141 ± 0.0040 | 0.9153 ± 0.0041 |
B.4 Meta Set Anomaly
Table 14 describes the architecture for the meta set anomaly experiments. We trained all models with the Adam optimizer and an exponentially decaying learning rate for 1,000 iterations. We subsampled 1,000 datasets from the CelebA dataset (see Figure 4) to train and test all methods, splitting them into 800 training datasets and 200 test datasets.
B.5 Point Cloud Classification
We used the ModelNet40 dataset for our point cloud classification experiments. This dataset consists of 3-dimensional representations of 9,843 training and 2,468 test objects, each belonging to one of 40 object classes. As input to our architectures, we produce point clouds with a fixed number of points each (each point is represented by its $(x, y, z)$ coordinates). To improve generalization, we randomly rotate and scale each set during training.
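A sketch of the rotate-and-scale augmentation, assuming rotation about the $z$-axis by a uniformly random angle and a uniform random scale in [0.8, 1.2]; the text does not specify the rotation axis or the ranges, so both are hypothetical. The same transform is applied to every point, so the result is still a valid set input and the order of points is preserved.

```python
import math
import random


def augment_point_cloud(points, rng, max_angle=math.pi, scale_range=(0.8, 1.2)):
    # Rotate about the z-axis by a random angle, then apply one random uniform scale.
    theta = rng.uniform(-max_angle, max_angle)
    s = rng.uniform(*scale_range)
    c, sn = math.cos(theta), math.sin(theta)
    return [[s * (c * x - sn * y), s * (sn * x + c * y), s * z] for x, y, z in points]


rng = random.Random(0)
points = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [1.0, 1.0, 1.0]]
augmented = augment_point_cloud(points, rng)
```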
We show results for our architectures in Table 15, and results of additional experiments that use a different number of points in Table 16. We trained using the Adam optimizer with an initial learning rate that we decayed by a constant factor at a fixed step interval.
| Architecture | Accuracy |
|---|---|
| rFF + Pooling | 0.7951 ± 0.0166 |
| rFF + PMA (1) | 0.8076 ± 0.0160 |
| ISAB (16) + Pooling | 0.8273 ± 0.0159 |
| Set Transformer (16) | 0.8454 ± 0.0144 |
| rFF + Pooling + tricks (Zaheer et al., 2017) | 0.82 ± 0.02 |