Learning Permutation Invariant Representations using Memory Networks
Abstract
Many real-world tasks, such as 3D object detection and high-resolution image classification, involve learning from a set of instances. In these cases, only a group of instances (a set) collectively contains meaningful information, and therefore only the sets, not the individual instances, have labels. In this work, we present a permutation invariant neural network called the Memory-based Exchangeable Model (MEM) for learning set functions. The model consists of memory units which embed an input sequence into high-level features (memories), enabling the model to learn inter-dependencies among the instances of a set in the form of attention vectors. To demonstrate its learning ability, we evaluated our model on test datasets created using MNIST, on point cloud classification, and on population estimation. We also tested the model on classifying histopathology whole slide images to discriminate between two subtypes of lung cancer: Lung Adenocarcinoma and Lung Squamous Cell Carcinoma. We systematically extracted patches from lung cancer images in The Cancer Genome Atlas (TCGA) dataset, the largest public repository of histopathology images. The proposed method achieved a competitive classification accuracy of 84.84%. The results on the other datasets are promising and demonstrate the efficacy of our model.
1 Introduction
Deep artificial neural networks have achieved impressive performance in representation learning. The majority of deep architectures take a single instance as input. However, we often need to learn representations for unordered sequential data, i.e., exchangeable sequences. Exchangeable data arise in many practical scenarios and are particularly relevant to Multiple Instance Learning (MIL), where a label is associated with a set instead of a single input. An example of MIL is the classification of high-resolution histopathology images, known as whole slide images (WSIs). These images are extremely large (e.g., 50,000 × 50,000 pixels), and labels are available for an entire image rather than as patch-level annotations. Patches can be extracted from these WSIs using various histological features, and a set of patches from a WSI can then be used for set classification or regression; one such example, LUAD/LUSC classification, is explained in Figure 1. Recurrent Neural Networks (RNNs) are a popular approach to learning representations from ordered sequential instances. However, their lack of permutation invariance renders RNNs ineffective for exchangeable sequences.
†Authors have contributed equally.
Exchangeable models are invariant to permutations of individual instances, i.e., the output of the network remains the same for all permutations of its input set. Finding the maximum value in a set of numbers is one example of such a function. Zaheer et al. [28] proposed an exchangeable model, called Deep Sets, for learning set functions. They proved that a pooling operation in a latent space can approximate any arbitrary set function. This result also underpins our claim that the proposed method is a universal approximator of set functions.
In this paper, we propose a novel architecture for exchangeable sequences that incorporates attention over the instances to learn their inter-dependencies. Our main contribution is a sequence-to-sequence permutation invariant layer called the Memory Block. The proposed model uses a series of connected memory block layers to model complex dependencies within an input set through a self-attention mechanism. Our model achieves state-of-the-art performance for LUAD/LUSC classification of WSIs, with 84.84% accuracy.
2 Related Work
In statistics, exchangeability has long been studied. De Finetti studied exchangeable random variables and showed that an infinite sequence of exchangeable random variables can be factorised into a mixture of independent and identically distributed random variables conditioned on some parameter $\theta$. Bayesian Sets [6] introduced a method to model exchangeable sequences of binary random variables by analytically computing the integrals in de Finetti's theorem. Orbanz et al. [16] used de Finetti's theorem for Bayesian modelling of graphs, matrices, and other data that can be modeled as random structures. Considerable work has also been done on partially exchangeable random variables [1].
Symmetry in neural networks was first proposed by Shawe et al. [20] under the name Symmetry Network. They proposed that invariance can be achieved through weight-preserving automorphisms of a neural network. Ravanbakhsh et al. proposed a similar method for equivariant networks through parameter sharing [18]. Bloem-Reddy and Teh [3] studied symmetry and exchangeability for neural networks in detail, established a link between functional and probabilistic symmetry, and obtained generative functional representations of joint and conditional probability distributions that are invariant or equivariant under the action of a compact group. Zhou et al. [30] proposed treating the instances in a set as non-i.i.d. samples for multi-instance learning.
Most of the work published in recent years has focused on ordered sets. Vinyals et al. introduced "Order Matters: Sequence to Sequence for Sets" in 2016 to learn sequence-to-sequence mappings for sets. Many related models and key contributions use the idea of external memories, such as RNNSearch [2], Memory Networks [25, 22] and Neural Turing Machines [8]. Recent interest in exchangeable models was driven by their application in MIL. Deep Symmetry Networks [5] used kernel-based interpolation to tractably tie parameters and pool over symmetry spaces of any dimension. Deep Sets [28] by Zaheer et al. proposed a permutation invariant model. They proved that a pooling operation (mean, sum, max or similar) over individual features, combined with suitable transformations, is a universal approximator for any set function. They also showed that any permutation invariant model follows de Finetti's theorem. Work has also been done on point cloud classification, which is an example of an MIL problem. Deep Learning with Sets and Point Clouds [17] used parameter sharing to obtain an equivariant layer. Another important exchangeable model is the Set Transformer [14] by Lee et al., which used results from Zaheer et al. [28] and proposed a Transformer-inspired [23] permutation invariant neural network. The Set Transformer uses attention mechanisms to model interactions among the inputs. Instead of averaging over instances as in Deep Sets, the Set Transformer uses a parametric aggregation function that can adapt to the problem at hand. Another way to handle exchangeable data is to modify RNNs to operate on exchangeable data. BRUNO [13] is a model for exchangeable data that uses deep features learned from observations to model complex data types such as images. To achieve this, the authors constructed a bijective mapping between random variables $x_i$ in the observation space and features $z_i$, and explicitly defined an exchangeable model for the sequences $z_1, \dots, z_n$.
Deep Amortized Clustering [15] proposed using Set Transformers to cluster sets of points with only a few forward passes. Deep Set Prediction Networks [29] introduced an approach to predict sets from a feature vector, in contrast to predicting an output from sets.
3 Background
In this section, we explain the general concepts of exchangeability and its relation to de Finetti's theorem, and briefly discuss Memory Networks.
Exchangeable Sequence: A sequence of random variables $x_1, x_2, \dots, x_n$ is exchangeable if its joint probability distribution does not change under a permutation of the indices. Mathematically, if
$p(x_1, \dots, x_n) = p(x_{\pi(1)}, \dots, x_{\pi(n)})$
for any permutation function $\pi$, then the sequence is exchangeable.
Exchangeable Models: A model is said to be exchangeable if its output is invariant to the permutation of its inputs. Exchangeability implies that the information provided by the instances does not depend on the order in which they are presented. Exchangeable models can be of two types depending on the application: i) permutation invariant, and ii) permutation equivariant.
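A model represented by a function $f: \mathcal{X} \to \mathcal{Y}$, where $\mathcal{X}$ is a set, is said to be permutation equivariant if permuting the input instances permutes the outputs with the same permutation $\pi$. Mathematically, a permutation-equivariant model satisfies
$f(x_{\pi(1)}, \dots, x_{\pi(n)}) = [f_{\pi(1)}(x_1, \dots, x_n), \dots, f_{\pi(n)}(x_1, \dots, x_n)].$
Similarly, a function $f$ is permutation invariant if permuting the input instances does not change the output of the model. Mathematically,
$f(x_1, \dots, x_n) = f(x_{\pi(1)}, \dots, x_{\pi(n)}).$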
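As an illustration of the two definitions, the following toy sketch (ours, not part of MEM) checks them over every permutation of a small input: an elementwise map is equivariant, a sum is invariant, and an order-dependent running fold, used here as a simple stand-in for an RNN-style update, is not invariant.

```python
import itertools

def elementwise_square(xs):
    # Equivariant: permuting the inputs permutes the outputs the same way.
    return [x * x for x in xs]

def set_sum(xs):
    # Invariant: the output ignores the order of the inputs.
    return sum(xs)

def recurrent_fold(xs, w=0.5):
    # Order-dependent toy recurrence h_t = w * h_{t-1} + x_t (RNN-like).
    h = 0.0
    for x in xs:
        h = w * h + x
    return h

def is_equivariant(f, xs):
    base = f(xs)
    for perm in itertools.permutations(range(len(xs))):
        permuted_in = [xs[i] for i in perm]
        permuted_out = [base[i] for i in perm]
        if f(permuted_in) != permuted_out:
            return False
    return True

def is_invariant(f, xs):
    base = f(xs)
    return all(f([xs[i] for i in perm]) == base
               for perm in itertools.permutations(range(len(xs))))

xs = [3.0, 1.0, 2.0]
print(is_equivariant(elementwise_square, xs))  # True
print(is_invariant(set_sum, xs))               # True
print(is_invariant(recurrent_fold, xs))        # False
```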
Deep Sets [28] uses a permutation-invariant model to learn arbitrary set functions by pooling in a latent space. The authors further showed that pooling operations such as averaging and max over the individual instances of a set can be used to build a universal approximator for any arbitrary set function. They proved the following two theorems.
Theorem 1: A function $f(X)$ operating on a set $X = \{x_1, \dots, x_M\}$ having elements from a countable universe is a valid set function, i.e., invariant to the permutation of instances in $X$, if and only if it can be decomposed in the form $\rho\left(\sum_{x \in X} \phi(x)\right)$, for suitable transformations $\phi$ and $\rho$.
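Theorem 2: Assume the elements are from a compact set in $\mathbb{R}^d$, i.e., possibly uncountable, and the set size is fixed to $M$. Then any continuous function operating on a set $X$, i.e., $f: \mathbb{R}^{d \times M} \to \mathbb{R}$, which is permutation invariant to the elements in $X$ can be approximated arbitrarily closely in the form $\rho\left(\sum_{x \in X} \phi(x)\right)$, for suitable transformations $\phi$ and $\rho$.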
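The decomposition $\rho\left(\sum_{x \in X} \phi(x)\right)$ can be made concrete with a hand-picked example (our own illustration, not taken from Deep Sets): with `phi(x) = (1, x, x^2)` and `rho(n, s1, s2) = s2/n - (s1/n)^2`, the composition computes the population variance of the set, so every ordering of the input gives the same result. Exact fractions are used so the check is not affected by floating-point summation order.

```python
import itertools
from fractions import Fraction

def phi(x):
    # Per-element embedding: (1, x, x^2).
    return (1, x, x * x)

def rho(n, s1, s2):
    # Maps the pooled statistics to the population variance E[x^2] - E[x]^2.
    return Fraction(s2, n) - Fraction(s1, n) ** 2

def deep_sets_variance(xs):
    # rho(sum over x of phi(x)): sum pooling in the latent space, then rho.
    n, s1, s2 = (sum(col) for col in zip(*(phi(x) for x in xs)))
    return rho(n, s1, s2)

xs = [2, 4, 4, 5, 7, 8]
values = {deep_sets_variance(p) for p in itertools.permutations(xs)}
print(values)  # {Fraction(4, 1)}: every ordering gives the same result
```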
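Theorem 1 is linked to de Finetti's theorem, which states that an infinitely exchangeable sequence of random variables can be factorised into a mixture of conditionally independent densities given some parameter $\theta$ that captures the underlying generative process, i.e.,
$p(x_1, \dots, x_n) = \int p(\theta) \prod_{i=1}^{n} p(x_i \mid \theta)\, d\theta. \qquad (1)$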
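A standard textbook instance of the representation in Eq. (1), not specific to this paper, is the Beta-Bernoulli model: with a $\mathrm{Beta}(\alpha, \beta)$ prior on $\theta$ and $x_i \mid \theta \sim \mathrm{Bernoulli}(\theta)$, the integral has the closed form $B(\alpha + k, \beta + n - k) / B(\alpha, \beta)$, where $k$ is the number of ones. The sketch below (prior values are illustrative assumptions) shows that the joint probability depends only on $k$, so the sequence is exchangeable, even though the variables are not independent.

```python
import itertools
import math

def log_beta(a, b):
    # log B(a, b) via log-gamma for numerical stability.
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_bernoulli_joint(xs, alpha=1.0, beta=1.0):
    # p(x_1..x_n) = integral of p(theta) * prod_i p(x_i | theta) d theta
    #             = B(alpha + k, beta + n - k) / B(alpha, beta)
    n, k = len(xs), sum(xs)
    return math.exp(log_beta(alpha + k, beta + n - k) - log_beta(alpha, beta))

seq = (1, 1, 0, 1, 0)
probs = {beta_bernoulli_joint(p) for p in set(itertools.permutations(seq))}
print(len(probs))  # 1: every ordering has the same probability

# Exchangeable but not independent: p(x1=1, x2=1) != p(x1=1) * p(x2=1).
p1 = beta_bernoulli_joint((1,))
p11 = beta_bernoulli_joint((1, 1))
print(p1 * p1, p11)  # roughly 0.25 versus 0.333
```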
Memory Networks: The idea of using an external memory for relational learning tasks was introduced by Weston et al. [25]. Later, an end-to-end trainable model was proposed by Sukhbaatar et al. [22]. Memory networks enable learning of dependencies among the instances of a set. A memory network embeds sequential data by providing an explicit memory representation for each instance in the sequence.
Taking an entire sequence and representing it as a single input offers the above-mentioned advantages with respect to memory, but it also has drawbacks. One of the main difficulties is that the order of the sequence is lost, as is the local proximity of sequence elements, which is crucial for most sequential inputs.
4 MEM
This section discusses the motivation and components of the proposed Memory-based Exchangeable Model (MEM), a neural network designed to work with infinitely exchangeable sequences, and offers an analysis of its properties.
4.1 Motivation
To learn an efficient representation for a set of instances, it is important to focus on instances that are highly correlated with the rest of the set, i.e., we need to attend to some instances more than others. We therefore use memory networks to learn an attention mapping for each instance. Memory networks are conventionally used in NLP to map questions posed in natural language to answers [25, 22]. We exploit the idea of memories that can learn key features shared by one or more instances. Through these key features, or memories, the model can learn inter-dependencies. Once the inter-dependencies are learnt, a set can be condensed into a compact vector so that an MLP can be used for classification or regression.
4.2 Components
MEM is composed of three sequentially connected units: i) a feature extraction
model, ii) memory blocks, and iii) fully connected layers to predict the output.
A memory block is the main component of MEM. It learns a permutation
invariant representation of a given input sequence. Multiple memory blocks can
be stacked together for modeling complex relationships and dependencies in
exchangeable data.
A memory block is a sequence-to-sequence model, i.e., it transforms a given input sequence $X = (x_1, \dots, x_n)$, with $x_i \in \mathbb{R}^d$, into another representative sequence $Y = (y_1, \dots, y_m)$. The output sequence is invariant to element-wise permutations of the input sequence. A memory block contains $m$ memory units. Each memory unit takes sequential data and produces a probability distribution over its input (known as an attention vector). The attention vectors from all memory units are transformed into the final output sequence of the memory block. The schematic diagram of a memory block is shown in (b).
A memory unit transforms a given input sequence into an attention vector. A higher attention value indicates a higher “importance” of the corresponding element of the input sequence. Essentially, the attention vector captures the relationships among different elements of the input. Multiple memory units enable the memory block to capture many complex dependencies and relationships among the elements. The memory unit embeds the $i$-th element $x_i$ of the input into a $p$-dimensional memory vector $m_i$ using an embedding matrix $A \in \mathbb{R}^{p \times d}$, as follows:
$m_i = \sigma(A x_i),$
where $\sigma$ is some non-linearity. The memory vectors are stacked to form a matrix $M$ of shape $n \times p$. The relative degree of correlation among the memory vectors is computed using cross-correlation followed by a softmax over each row, and then taking an average over the individual instances:
$a_j = \frac{1}{n} \sum_{i=1}^{n} \frac{\exp(m_i^\top m_j)}{\sum_{k=1}^{n} \exp(m_i^\top m_k)}, \quad j = 1, \dots, n. \qquad (2)$
The vector $a = (a_1, \dots, a_n)$ is the final output of the memory unit: an attention vector over $X$ representing the correlations among the elements when projected to the embedding space.
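A minimal pure-Python sketch of a single memory unit under our reading of Eq. (2): embed each element, take the cross-correlation of the memory vectors, apply a softmax to each row, then average the rows. The `tanh` non-linearity, the random embedding matrix `A` and the sizes `d`, `p`, `n` are assumptions for illustration, not values from the paper. The last line checks that the unit is permutation equivariant.

```python
import math
import random

def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def softmax(v):
    mx = max(v)
    exps = [math.exp(t - mx) for t in v]
    total = sum(exps)
    return [e / total for e in exps]

def memory_unit(X, A, sigma=math.tanh):
    """Attention vector of one memory unit.

    X: list of n input vectors (each of length d).
    A: p x d embedding matrix (list of rows).
    sigma: elementwise non-linearity (tanh is an assumption).
    """
    # Embed every element: m_i = sigma(A x_i).
    M = [[sigma(v) for v in matvec(A, x)] for x in X]
    # Cross-correlation C_ij = m_i . m_j, softmax over each row.
    rows = [softmax([sum(u * v for u, v in zip(mi, mj)) for mj in M]) for mi in M]
    n = len(X)
    # Average the rows: a_j = (1/n) * sum_i softmax_j(C_i).
    return [sum(row[j] for row in rows) / n for j in range(n)]

random.seed(0)
d, p, n = 4, 3, 5
A = [[random.gauss(0, 1) for _ in range(d)] for _ in range(p)]
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
a = memory_unit(X, A)
print(sum(a))  # ~1.0: a probability distribution over the n inputs

perm = [2, 0, 4, 1, 3]
a_perm = memory_unit([X[i] for i in perm], A)
# Equivariance: permuting X permutes the attention vector the same way.
print(all(abs(a_perm[k] - a[perm[k]]) < 1e-12 for k in range(n)))
```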
The purpose of a memory unit is to embed feature vectors into another space that could correspond to a distinct “attribute” or “characteristic” of the instances. The cross-correlation, or the resulting attention vector, highlights the instances that are highly suggestive of that “attribute” or “characteristic”.
The final output sequence of a memory block is computed as a weighted sum of $Z$ using the probability distributions from all the memory units. The matrix $Z = (z_1, \dots, z_n)$ is a bijective transformation of $X$ learned using an autoencoding neural network; each memory block has its own autoencoder to learn this mapping. The $j$-th element of the output sequence is given by
$y_j = \sum_{i=1}^{n} a^{(j)}_i z_i, \quad j = 1, \dots, m,$
where $z_i$ is the bijective transformation of $x_i$ learned with an autoencoder and $a^{(j)}$ is the output of the $j$-th memory unit as given by (2).
The bijective transformation from $X$ to $Z$ provides an equivariant correspondence between the elements of the two sequences $X$ and $Z$, and preserves the same level of expressiveness or information in $Z$ as in $X$. We used a deep autoencoder as the bijective transformation function:
$z_i = E(x_i), \quad \hat{x}_i = D(z_i) \approx x_i,$
where $E$ and $D$ are the encoder and decoder networks, respectively.
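The block-level combination can be sketched in the same way. Training an autoencoder is out of scope for a short example, so a fixed invertible lower-triangular matrix `W` stands in for the encoder E (a hypothetical stand-in); the memory-unit internals follow our reading of Eq. (2) with an assumed `tanh`. The sketch checks that shuffling the input set leaves the $m$ output elements unchanged, which is the behaviour claimed for memory blocks.

```python
import math
import random

def softmax(v):
    mx = max(v)
    exps = [math.exp(t - mx) for t in v]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def memory_unit(X, A):
    # Attention over the n inputs (our reading of Eq. (2); tanh is assumed).
    M = [[math.tanh(v) for v in matvec(A, x)] for x in X]
    rows = [softmax([sum(u * v for u, v in zip(mi, mj)) for mj in M]) for mi in M]
    return [sum(r[j] for r in rows) / len(X) for j in range(len(X))]

def memory_block(X, As, encode):
    """y_j = sum_i a_i^(j) z_i with z_i = encode(x_i); one y_j per memory unit."""
    Z = [encode(x) for x in X]
    Y = []
    for A in As:  # one embedding matrix per memory unit
        a = memory_unit(X, A)
        Y.append([sum(a[i] * Z[i][k] for i in range(len(X))) for k in range(len(Z[0]))])
    return Y

random.seed(1)
d, p, n, m = 4, 3, 6, 2
As = [[[random.gauss(0, 1) for _ in range(d)] for _ in range(p)] for _ in range(m)]
# Hypothetical stand-in for the trained encoder E: an invertible
# lower-triangular linear map (unit diagonal), hence a bijection.
W = [[1.0 if r == c else (0.5 if c < r else 0.0) for c in range(d)] for r in range(d)]

def encode(x):
    return matvec(W, x)

X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
Y = memory_block(X, As, encode)
random.shuffle(X)                      # permute the input set
Y_shuffled = memory_block(X, As, encode)
print(len(Y), len(Y[0]))               # 2 4  (m outputs of the encoder's size)
print(max(abs(u - v) for y, ys in zip(Y, Y_shuffled) for u, v in zip(y, ys)))  # ~0
```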

4.3 Analysis
This section discusses the mathematical properties of our model. We use theorems from Deep Sets [28] to show that our model is permutation invariant and a universal approximator for arbitrary set functions.
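Property 1: Memory units are permutation equivariant. Consider an input set $X$, and let each memory unit be represented by a transformation $a = g(X)$, where $a$ is obtained by averaging each column of the row-normalised cross-correlation matrix. Permuting the inputs by $\pi$ permutes both the rows and the columns of the cross-correlation matrix by $\pi$; the row-wise softmax and the average over rows therefore return the same attention values reordered by $\pi$. Hence $g$ is a permutation equivariant transformation.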
Property 2: Memory blocks are permutation invariant. Each memory block layer consists of multiple memory units represented by transformation functions $g_1, \dots, g_m$. The output of the $k$-th memory block layer, consisting of $m$ memory units, is a sequence of $m$ instances. Mathematically, it can be written as
$Y^{(k)} = \left( \sum_{i=1}^{n} g_j(X)_i \, h(x_i) \right)_{j=1}^{m},$
where $h$ is a bijective transformation. Since both $g_j$ and $h$ are permutation equivariant, permuting the inputs reorders the terms of each sum in the same way, so every sum, and hence the output $Y^{(k)}$, is permutation invariant.
Property 3: MEM is a universal approximator of any permutation invariant set function. The output of our model can be written as $\rho(\Phi(X))$, where $\Phi$ represents the transformation of the stacked memory block layers and $\rho$ the subsequent layers. Following Theorem 2 in Section 3, MEM can be used as a universal approximator for set functions.
4.4 Proposed Model
- Each element of a given input sequence is passed through a feature extraction model to produce a sequence of feature vectors $X = (x_1, \dots, x_n)$. The feature extraction model can be a Convolutional Neural Network (CNN), a Multi-Layer Perceptron (MLP), or any other differentiable feature extractor.
- The feature sequence is then passed through a memory block to obtain another sequence $Y = (y_1, \dots, y_m)$, which is a permutation-invariant representation of the input sequence.
- The number of elements in $Y$, i.e., $m$, and the dimensionality of each element of $Y$ are defined by the hyper-parameters of the memory block.
- Multiple memory blocks can be stacked in series. The output of the last memory block is either vectorized or pooled and then passed to an MLP for classification or regression. The vectorization and pooling operations at the final memory block give rise to different configurations, which are discussed in the next section.
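The final stage reduces the $m$ outputs of the last memory block to a fixed-length vector for the MLP. The helper names below are hypothetical and the values are toy numbers; the sketch only shows that flattening (vectorizing) is safe because the order of $Y$ is set by the memory-unit index rather than by the input order, whereas pooling additionally merges the contributions of different memory units.

```python
def vectorize(Y):
    # Concatenate the m outputs; their order is fixed by the memory-unit index,
    # not by the input order, so the result stays permutation invariant.
    return [v for y in Y for v in y]

def mean_pool(Y):
    return [sum(col) / len(Y) for col in zip(*Y)]

def max_pool(Y):
    return [max(col) for col in zip(*Y)]

def sum_pool(Y):
    return [sum(col) for col in zip(*Y)]

Y = [[0.2, -1.0, 0.5], [0.4, 0.0, -0.5]]   # m = 2 outputs of dimension 3 (toy values)
print(vectorize(Y))  # [0.2, -1.0, 0.5, 0.4, 0.0, -0.5]
print(max_pool(Y))   # [0.4, 0.0, 0.5]
```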
5 Model Evaluation
We performed two series of experiments comparing MEM against simple pooling operations. In the first series, we established the learning ability of the proposed model using toy datasets. In the second series, we applied our model to real-world data, including the classification of lung cancer subtypes on the largest public dataset of histopathology whole slide images (WSIs) [24].
We experimented with different choices of pooling operations: max, mean, dot product, and sum. Our model also has a special pooling, MB1, i.e., a memory block with a single memory unit. Therefore, we tested 9 different models for each experiment: five configurations of our model, and four configurations with only pooling operations.
5.1 Experiment Series 1: Toy Datasets
Table 1: Results on the toy datasets.

| Method | Sum of Even Digits: Accuracy | Sum of Even Digits: MAE | Sum is Prime: Accuracy | Counting Unique Images: Accuracy | Counting Unique Images: MAE | Maxima of Set: Accuracy | Maxima of Set: MAE | Gaussian Clustering: NLL |
|---|---|---|---|---|---|---|---|---|
| FF + MEM + MB1 (ours) | 0.9367 ± 0.0016 | 0.2516 ± 0.0105 | 0.9438 ± 0.0043 | 0.7108 ± 0.0084 | 0.3931 ± 0.0080 | 0.9326 ± 0.0036 | 0.1449 ± 0.0068 | 1.348 |
| FF + MEM + Mean (ours) | 0.9355 ± 0.0015 | 0.2437 ± 0.0087 | 0.7208 ± 0.0217 | 0.4264 ± 0.0062 | 0.9525 ± 0.0109 | 0.9445 ± 0.0035 | 0.1073 ± 0.0067 | 1.523 |
| FF + MEM + Max (ours) | 0.9431 ± 0.0020 | 0.2295 ± 0.0098 | 0.9361 ± 0.0060 | 0.6888 ± 0.0066 | 0.4140 ± 0.0079 | 0.9498 ± 0.0022 | 0.1086 ± 0.0060 | 1.388 |
| FF + MEM + Dotprod (ours) | 0.8411 ± 0.0045 | 0.3932 ± 0.0065 | 0.9450 ± 0.0086 | 0.7284 ± 0.0055 | 0.3664 ± 0.0037 | 0.9517 ± 0.0041 | 0.0999 ± 0.0097 | 1.363 |
| FF + MEM + Sum (ours) | 0.9353 ± 0.0022 | 0.2739 ± 0.0081 | 0.6652 ± 0.0389 | 0.3138 ± 0.0094 | 1.3696 ± 0.0151 | 0.9430 ± 0.0031 | 0.1318 ± 0.0058 | 1.611 |
| FF + Mean | 0.9159 ± 0.0019 | 0.2958 ± 0.0049 | 0.5280 ± 0.0078 | 0.3140 ± 0.0071 | 1.2169 ± 0.0136 | 0.3223 ± 0.0075 | 1.0029 ± 0.0155 | 2.182 |
| FF + Max | 0.6291 ± 0.0047 | 1.3292 ± 0.0211 | 0.9257 ± 0.0033 | 0.7088 ± 0.0060 | 0.3933 ± 0.0059 | 0.9585 ± 0.0012 | 0.0742 ± 0.0032 | 1.608 |
| FF + Dotprod | 0.1503 ± 0.0015 | 1.8015 ± 0.0016 | 0.9224 ± 0.0028 | 0.7254 ± 0.0063 | 0.3726 ± 0.0054 | 0.9548 ± 0.0017 | 0.1355 ± 0.0027 | 8.538 |
| FF + Sum | 0.6333 ± 0.0043 | 0.5763 ± 0.0069 | 0.5264 ± 0.0050 | 0.2982 ± 0.0042 | 1.3415 ± 0.0169 | 0.3344 ± 0.0038 | 0.9645 ± 0.0111 | 12.05 |
To demonstrate the advantage of MEM over simple pooling operations, we consider
four toy problems, involving regression or classification over sets. We
assembled these toy problems using simple computer vision datasets.
Sum of Even Digits is a regression problem over a set of images of handwritten digits from MNIST. For a given set of images $\{x_1, \dots, x_n\}$, the goal is to find the sum of all even digits. We used the mean absolute error $|y - \hat{y}|$ as the loss function, where $y$ is the required sum and $\hat{y}$ is the model's prediction. We split the MNIST dataset 70-30% into training and testing datasets, respectively. We sampled 100,000 sets of 2 to 10 images from the training data and trained each model for 100 epochs using the SGD optimizer. For testing, we sampled 10,000 sets containing $n$ images per set for varying $n$. Figure 4 shows the performance of MEM against simple pooling operations with respect to the number of images in the set.
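The label construction for this task is easy to state in code. The sketch below works on digit labels only (random integers stand in for MNIST labels, an assumption) and shows the set sampling, the target, and the MAE loss; it is an illustration, not the authors' data pipeline.

```python
import random

def sum_of_even_digits(digits):
    # Regression target: sum of the even digits in the set.
    return sum(d for d in digits if d % 2 == 0)

def sample_training_set(labels, rng, min_size=2, max_size=10):
    # Draw a set of 2 to 10 images (represented here by their indices).
    size = rng.randint(min_size, max_size)
    idx = rng.sample(range(len(labels)), size)
    return idx, sum_of_even_digits(labels[i] for i in idx)

def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

rng = random.Random(0)
fake_labels = [rng.randrange(10) for _ in range(1000)]  # stand-in for MNIST labels
idx, target = sample_training_set(fake_labels, rng)
print(sum_of_even_digits([3, 4, 8, 1]))  # 12
```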

Sum is Prime is a classification problem over a set of MNIST images. A set is labeled positive if it contains any two digits whose sum is a prime number. We constructed each set by randomly sampling five images from the MNIST dataset. Since the label depends on only two digits of the set, attention plays an important role. We constructed the training data with 20,000 sets randomly sampled from the MNIST training data. For testing, we randomly sampled 5,000 sets from the MNIST testing data. We used binary cross-entropy loss with the Adam optimizer [12] and trained the network for 100 epochs. The results, reported in the second column group of Table 1, show the robustness of the memory block: the MB1, Max and Dotprod configurations of MEM outperform all pooling baselines.
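The labelling rule can be written as a pairwise check. We read "any two digits" as any two distinct elements of the set; this is an assumption, since the text does not say whether a digit may be paired with itself.

```python
from itertools import combinations

def is_prime(k):
    if k < 2:
        return False
    return all(k % p for p in range(2, int(k ** 0.5) + 1))

def sum_is_prime_label(digits):
    # Positive (1) if some pair of distinct elements has a prime sum.
    return int(any(is_prime(a + b) for a, b in combinations(digits, 2)))

print(sum_is_prime_label([0, 4, 6, 8, 8]))  # 0: all pairwise sums are even and at least 4
print(sum_is_prime_label([1, 4, 6, 8, 8]))  # 1: 1 + 4 = 5 is prime
```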
Maxima of a Set is a regression problem that predicts the highest digit present in a set of MNIST images. We constructed each set by randomly selecting five images from the MNIST dataset. The label for each set is the largest digit present in the set; for example, a set of images of the digits $d_1, \dots, d_5$ is labeled $\max_i d_i$. We constructed 20,000 training sets and, for testing, randomly sampled 5,000 sets each time. We used MSE loss with the Adam optimizer and trained the model for 200 epochs, with a linear activation in the last layer. We obtained the predicted maximum by rounding the model's prediction. The detailed comparison of accuracy and MAE between the different models is given in the second-to-last column group of Table 1. We found that FF+Max learns the identity mapping and thus achieves a very high accuracy. In all training sessions, we consistently obtained a training accuracy of 100% for the FF+Max configuration, whereas MEM shows a much smaller gap between training and test accuracy.
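A small sketch of the target and of how a real-valued prediction from the linear output layer is turned into a digit by rounding; clipping to the range 0 to 9 is an added assumption, not stated in the text.

```python
def maxima_label(digits):
    return max(digits)

def predicted_maximum(raw_output):
    # Round the real-valued output to the nearest digit and clip it to the
    # valid range 0-9 (the clipping is our assumption).
    return min(9, max(0, round(raw_output)))

def accuracy(targets, raw_outputs):
    hits = sum(t == predicted_maximum(r) for t, r in zip(targets, raw_outputs))
    return hits / len(targets)

print(maxima_label([2, 7, 1, 7, 0]))        # 7
print(accuracy([7, 9, 3], [6.8, 9.4, 2.4]))  # 0.6666666666666666
```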
Counting Unique Images is another regression problem over a set. It involves counting the unique objects in a set of images from the Fashion-MNIST dataset [27]. We constructed each training set as follows:
- Randomly select an integer $n$ between 2 and 10.
- Randomly select another integer $k$ between 1 and $n$.
- Select $k$ unique objects from the Fashion-MNIST training data.
- Add $n - k$ objects randomly selected from those chosen in the previous step.
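The four construction steps can be sketched as below, with image indices standing in for Fashion-MNIST images. Treating "unique" as "distinct image" and shuffling the final set are our assumptions; 60,000 is the size of the Fashion-MNIST training split.

```python
import random

def make_counting_set(num_items, rng):
    """Build one set following the four construction steps.

    num_items: size of the image pool (stand-in for Fashion-MNIST training images).
    Returns (item indices, number of unique items).
    """
    n = rng.randint(2, 10)                                # step 1: set size
    k = rng.randint(1, n)                                 # step 2: number of unique objects
    unique = rng.sample(range(num_items), k)              # step 3: k distinct objects
    repeats = [rng.choice(unique) for _ in range(n - k)]  # step 4: duplicates of them
    items = unique + repeats
    rng.shuffle(items)
    return items, k

rng = random.Random(42)
items, k = make_counting_set(60000, rng)
print(len(set(items)) == k, 2 <= len(items) <= 10)  # True True
```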
The task is to count the unique objects in a given set. We used a softplus activation in the last layer, defined as $\mathrm{softplus}(x) = \log(1 + e^x)$, and trained the model by maximizing the Poisson likelihood. The model was trained for 200 epochs using the SGD optimizer. The results achieved by MEM are similar to those of feature pooling; see the third column group of Table 1.
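The output layer and loss described above can be sketched as follows: a softplus keeps the predicted Poisson rate positive, and maximizing the Poisson likelihood amounts to minimizing the negative log-likelihood of the true count. The function names are illustrative.

```python
import math

def softplus(x):
    # log(1 + e^x), written to avoid overflow for large x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def poisson_nll(count, raw_output):
    # Rate lambda = softplus(raw_output) > 0; the loss is -log p(count | lambda).
    lam = softplus(raw_output)
    return lam - count * math.log(lam) + math.lgamma(count + 1)

# The loss is smallest when the predicted rate matches the true count.
raw_for_rate_3 = math.log(math.expm1(3.0))   # inverse softplus of 3
print(softplus(raw_for_rate_3))               # ~3.0
print(poisson_nll(3, raw_for_rate_3) < poisson_nll(3, 0.0))  # True
```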
Amortized Gaussian Clustering is a regression problem that involves estimating the parameters of a Mixture of Gaussians (MoG). Similar to the Set Transformer [14], we test our model on learning the parameters of a Gaussian mixture with $K$ components such that the likelihood of the observed samples is maximized. This is in contrast to the EM algorithm, which updates the parameters of the mixture iteratively until a stopping criterion is satisfied. Instead, we use MEM to directly predict the parameters of the MoG, i.e., $f(X) = \{\pi_k, \mu_k, \sigma_k\}_{k=1}^{K}$.
For simplicity, we sample from an MoG with only four components. The generative process for each training set is as follows:
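- The mean $\mu_k$ of each Gaussian is drawn from a uniform distribution.
- A cluster is selected for each instance in the set, i.e., $z_i \sim \mathrm{Categorical}(\pi)$.
- The data are generated from a univariate Gaussian, $x_i \sim \mathcal{N}(\mu_{z_i}, \sigma_{z_i}^2)$.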
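The generative process can be sketched as below. The uniform range for the means, the shared standard deviation and the equal mixture weights are placeholder assumptions, since their values are not given in the text; the 500 points and four components follow the description.

```python
import random

def sample_mog_set(rng, num_points=500, num_components=4,
                   mean_range=(-4.0, 4.0), sigma=0.3):
    """Sample one training set from a 1-D mixture of Gaussians.

    mean_range, sigma and the uniform mixture weights are assumptions;
    the text does not give these values.
    """
    means = [rng.uniform(*mean_range) for _ in range(num_components)]  # mu_k ~ U
    weights = [1.0 / num_components] * num_components                 # pi (assumed uniform)
    points, assignments = [], []
    for _ in range(num_points):
        k = rng.choices(range(num_components), weights=weights)[0]    # z_i ~ Cat(pi)
        assignments.append(k)
        points.append(rng.gauss(means[k], sigma))                     # x_i ~ N(mu_z, sigma^2)
    return points, means, assignments

rng = random.Random(0)
X, means, z = sample_mog_set(rng)
print(len(X), len(means))  # 500 4
```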
We created a dataset of 20,000 sets, each consisting of 500 points sampled from a different MoG. We used the Negative Log-Likelihood (NLL) as the loss function and the Adam optimizer to minimise it. The results in Table 1 show that MEM outperforms feature pooling: the best MEM configuration (FF + MEM + MB1) reaches an NLL of 1.348, compared with 1.608 for the best pooling baseline (FF + Max).
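The training objective can be written down directly: given the predicted mixture weights, means and standard deviations, the loss is the negative log-likelihood of the set. This is a generic MoG likelihood computed with log-sum-exp for numerical stability, not code from the paper; whether the loss is averaged or summed over points is not stated, so averaging is an assumption.

```python
import math

def mog_nll(points, weights, means, sigmas):
    """Average negative log-likelihood of 1-D points under a mixture of Gaussians."""
    total = 0.0
    for x in points:
        # log sum_k pi_k N(x | mu_k, sigma_k^2), computed with log-sum-exp.
        logs = [math.log(w) - 0.5 * math.log(2 * math.pi * s * s)
                - (x - mu) ** 2 / (2 * s * s)
                for w, mu, s in zip(weights, means, sigmas)]
        mx = max(logs)
        total -= mx + math.log(sum(math.exp(lp - mx) for lp in logs))
    return total / len(points)

pts = [-2.1, -1.9, 2.0, 2.2]
good = mog_nll(pts, [0.5, 0.5], [-2.0, 2.1], [0.2, 0.2])
bad = mog_nll(pts, [0.5, 0.5], [0.0, 5.0], [0.2, 0.2])
print(good < bad)  # True: parameters near the data give a lower NLL
```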
5.2 Experiment Series 2: Real-world Datasets
To show the robustness and scalability of the model on real-world problems, we validated MEM on two larger datasets. First, we tested our model on a point cloud dataset, predicting the object type from a set of 3D coordinates. Second, we used the largest public repository of histopathology images (TCGA) [24] to differentiate between the two main subtypes of lung cancer. Without any significant effort in extracting histologically relevant features or fine-tuning, we achieved an accuracy of 84.84% with 5-fold cross-validation.
5.3 Point Cloud Classification
We evaluated MEM on a more complex classification task using the ModelNet40 [26] point cloud dataset. The dataset consists of 40 different object classes embedded in three-dimensional space as points. We produce point clouds with 100 and 1000 points each (x, y, z coordinates) from the mesh representation of the objects using the Point Cloud Library's sampling routine [19] (we obtained the training and test datasets from Zaheer et al. [28]). We trained our model with categorical cross-entropy loss and the Adam optimizer, along with batch normalization [10] and dropout [21]. We compare its performance against Deep Sets [28] and the Set Transformer [14]. We experimented with different configurations of our model and found that FF + MEM + MB1 works best for 100-point cloud classification.
We achieve a classification accuracy of 85.21% using 100 points. Our model performs better than Deep Sets and the Set Transformer with 100 points, which suggests the effectiveness of attention from memories. For 1000 points, we achieved an accuracy of 87.30%, which is comparable to Deep Sets. Our approach relies on inter-dependencies and “attention” among instances to form a representation of a set. We believe our model performs better with 100 points because inter-dependencies play a larger role there than with 1000 points: when there are many points, the information might already be contained in the points as a whole, and relational representations contribute less. The comparison of our model with Deep Sets and the Set Transformer is shown in Table 2.
Table 2: Point cloud classification accuracy on ModelNet40.

| Configuration | 100 pts | 1000 pts |
|---|---|---|
| Deep Sets | 0.8200 | 0.8700 |
| Set Transformer | 0.8454 | 0.8915 |
| Ours (FF + MEM + MB1) | 0.8521 | 0.8730 |
5.4 Lung Cancer Subtype Classification
Lung Adenocarcinoma (LUAD) and Lung Squamous Cell Carcinoma (LUSC) are the two main types of non-small cell lung cancer (NSCLC) and account for 65-70% of all lung cancers [7]. Classifying patients accurately is important for prognosis and therapy decisions. Automated classification of these two main subtypes of NSCLC is a crucial step towards building computerized decision support and triaging systems. We present a two-stage method to differentiate LUAD and LUSC in whole slide images (WSIs), which are very large images. First, we systematically sample patches (tiles) from the WSIs. Next, we extract image features from the patches using DenseNet [9]. We arrange the image features into sets (one per WSI) and train MEM to classify the two subtypes. The highest accuracy achieved with 5-fold cross-validation is 84.84%.
To the best of our knowledge, this is the first study conducted on all the lung cancer slides in the TCGA dataset (2 TB of data, comprising 2.5 million patches of 1000 × 1000 pixels). All existing works we are aware of use a subset of the slides with their own train-test split instead of cross-validation, which makes direct comparison difficult. Nevertheless, our results are better than or similar to those of existing works, obtained without any expert (pathologist) input or domain-specific techniques.
We downloaded 2,580 WSIs from the TCGA public repository [24], with 1,249 and 1,331 slides for LUAD and LUSC, respectively. To train our model, each WSI is converted into a set of features. These are deep features extracted from a set of representative patches of the given WSI. We process each WSI as follows:
- Tissue Extraction – Every WSI contains a bright background that generally carries irrelevant (non-tissue) pixel information. We removed non-tissue regions using color thresholds.
- Selecting Representative Patches – The segmented tissue is divided into patches. All patches are then grouped into a pre-set number of categories (classes) via a clustering method. From each cluster, 10% of the patches are selected uniformly to assemble the representative patches. Six representative patches for each class (LUAD and LUSC) are shown in Figure 5.
- Feature Set – A set of features for each WSI is created by converting its representative patches into image features. We use DenseNet [9] as the feature extraction model. The number of features differs from one WSI to another.
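The first two processing steps can be sketched as follows. The brightness threshold, the averaging of RGB channels, rounding up to at least one patch per cluster, and the toy cluster assignment are all hypothetical choices for illustration; the text does not specify its thresholds or clustering method.

```python
import random

def is_tissue(rgb, background_threshold=220):
    # Hypothetical rule: pixels whose mean RGB intensity reaches the threshold
    # are treated as bright background; the value 220 is an assumption.
    r, g, b = rgb
    return (r + g + b) / 3 < background_threshold

def select_representative(patch_ids, cluster_of, rng, percent=10):
    """Pick `percent`% of the patches in every cluster (rounded up, at least one)."""
    by_cluster = {}
    for pid in patch_ids:
        by_cluster.setdefault(cluster_of[pid], []).append(pid)
    chosen = []
    for members in by_cluster.values():
        take = max(1, -(-len(members) * percent // 100))  # ceiling division
        chosen.extend(rng.sample(members, take))
    return chosen

rng = random.Random(7)
patches = list(range(200))
clusters = {pid: pid % 5 for pid in patches}  # stand-in clustering: 5 clusters of 40
picked = select_representative(patches, clusters, rng)
print(is_tissue((250, 248, 252)), is_tissue((180, 120, 170)))  # False True
print(len(picked))  # 20
```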
Table 3: LUAD/LUSC classification accuracy (5-fold cross-validation).

| Configuration | Accuracy |
|---|---|
| FF + MEM + Sum (ours) | 0.8484 ± 0.0210 |
| FF + MEM + Mean (ours) | 0.8465 ± 0.0225 |
| FF + MEM + MB1 (ours) | 0.8457 ± 0.0219 |
| FF + MEM + Dotprod (ours) | 0.6345 ± 0.0739 |
| FF + Sum | 0.5159 ± 0.0120 |
| FF + Mean | 0.7777 ± 0.0273 |
| FF + Dotprod | 0.4112 ± 0.0121 |
The results are shown in Table 3. We achieved the highest accuracy of 84.84% with the FF + MEM + Sum configuration. It is difficult to compare our approach against others in the literature because the dataset is not standardized. Coudray et al. [4] used around 1,634 TCGA slides to classify LUAD and LUSC and achieved an AUC of 0.947 using patches at 20× magnification. We achieved a similar AUC of for one of the folds and an averaged AUC of . It is important to note that we did not perform any fine-tuning or use any input from an expert or pathologist. Instead, we extracted diverse patches and let the model learn to differentiate between the two subtypes by “attending” to the relevant ones. Another study, by Jaber et al. [11], uses cell density maps and achieves an accuracy of 83.33% and an AUC of 0.9068. However, they used a much smaller portion of TCGA: 338 diagnostic WSIs (164 LUAD and 174 LUSC) for training and 150 (71 LUAD and 79 LUSC) for testing.
6 Conclusion
In this paper, we introduced the Memory-based Exchangeable Model (MEM) for learning permutation invariant representations. The proposed method uses attention mechanisms over “memories” (higher-order features) to model complicated interactions among the elements of a set. Typically in MIL, instances are treated as independent and identically distributed. However, instances are rarely independent in real tasks, and we overcome this limitation using an “attention” mechanism in memory units that exploits relations among instances. We also provided some theoretical properties of our model, including the fact that it is a universal approximator for permutation invariant functions. We achieved good performance on all problems that require exploiting instance relationships. Our model also scales well to real-world problems, achieving an accuracy of 84.84% on classifying lung cancer subtypes on the largest public repository of histopathology images.
Acknowledgements – The funds for this research have been provided by the ORF-RE program (Ontario Research Fund - Research Excellence). Core research was also supported by NSERC (Natural Sciences and Engineering Research Council of Canada). The first author's Ph.D. is supported by MITACS (Mathematics of Information Technology and Complex Systems), and the second author's MASc. is funded by the Vector Institute. The GPU cluster was graciously provided by Prof. Graham Taylor, allowing us to run the experiments for this study. The authors are thankful to Sultaan Shah at Huron Digital Pathology, Canada, for writing C/C++ code for faster extraction of patches and features from histopathology images.
References
- [1] D. J. Aldous. Representations for partially exchangeable arrays of random variables. Journal of Multivariate Analysis, 11(4):581–598, 1981.
- [2] D. Bahdanau, K. Cho, and Y. Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.
- [3] B. Bloem-Reddy and Y. W. Teh. Probabilistic symmetry and invariant neural networks. arXiv preprint arXiv:1901.06082, 2019.
- [4] N. Coudray, P. S. Ocampo, T. Sakellaropoulos, N. Narula, M. Snuderl, D. Fenyö, A. L. Moreira, N. Razavian, and A. Tsirigos. Classification and mutation prediction from non–small cell lung cancer histopathology images using deep learning. Nature Medicine, 24(10):1559, 2018.
- [5] R. Gens and P. M. Domingos. Deep symmetry networks. In Advances in Neural Information Processing Systems, pages 2537–2545, 2014.
- [6] Z. Ghahramani and K. A. Heller. Bayesian sets. In Advances in Neural Information Processing Systems, pages 435–442, 2006.
- [7] S. Graham, M. Shaban, T. Qaiser, N. A. Koohbanani, S. A. Khurram, and N. Rajpoot. Classification of lung cancer histology images using patch-level summary statistics. In Medical Imaging 2018: Digital Pathology, volume 10581, page 1058119. International Society for Optics and Photonics, 2018.
- [8] A. Graves, G. Wayne, and I. Danihelka. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.
- [9] G. Huang, Z. Liu, L. Van Der Maaten, and K. Q. Weinberger. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4700–4708, 2017.
- [10] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.
- [11] M. I. Jaber, L. Beziaeva, C. W. Szeto, J. Elshimali, S. Rabizadeh, and B. Song. Automated adeno/squamous-cell NSCLC classification from diagnostic slide images: A deep-learning framework utilizing cell-density maps, 2019.
- [12] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
- [13] I. Korshunova, J. Degrave, F. Huszár, Y. Gal, A. Gretton, and J. Dambre. BRUNO: A deep recurrent model for exchangeable data. In Advances in Neural Information Processing Systems, pages 7190–7198, 2018.
- [14] J. Lee, Y. Lee, J. Kim, A. R. Kosiorek, S. Choi, and Y. W. Teh. Set Transformer. arXiv preprint arXiv:1810.00825, 2018.
- [15] J. Lee, Y. Lee, and Y. W. Teh. Deep amortized clustering. arXiv preprint arXiv:1909.13433, 2019.
- [16] P. Orbanz and D. M. Roy. Bayesian models of graphs, arrays and other exchangeable random structures. IEEE Transactions on Pattern Analysis and Machine Intelligence, 37(2):437–461, 2014.
- [17] S. Ravanbakhsh, J. Schneider, and B. Poczos. Deep learning with sets and point clouds. arXiv preprint arXiv:1611.04500, 2016.
- [18] S. Ravanbakhsh, J. Schneider, and B. Poczos. Equivariance through parameter-sharing. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 2892–2901. JMLR.org, 2017.
- [19] R. Rusu and S. Cousins. 3D is here: Point Cloud Library (PCL). In 2011 IEEE International Conference on Robotics and Automation (ICRA), pages 1–4, May 2011.
- [20] J. Shawe-Taylor. Symmetries and discriminability in feedforward network architectures. IEEE Transactions on Neural Networks, 4(5):816–826, 1993.
- [21] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
- [22] S. Sukhbaatar, J. Weston, R. Fergus, et al. End-to-end memory networks. In Advances in Neural Information Processing Systems, pages 2440–2448, 2015.
- [23] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pages 5998–6008, 2017.
- [24] J. N. Weinstein, E. A. Collisson, G. B. Mills, K. R. M. Shaw, B. A. Ozenberger, K. Ellrott, I. Shmulevich, C. Sander, J. M. Stuart, Cancer Genome Atlas Research Network, et al. The Cancer Genome Atlas Pan-Cancer analysis project. Nature Genetics, 45(10):1113, 2013.
- [25] J. Weston, S. Chopra, and A. Bordes. Memory networks. arXiv preprint arXiv:1410.3916, 2014.
- [26] Z. Wu, S. Song, A. Khosla, F. Yu, L. Zhang, X. Tang, and J. Xiao. 3D ShapeNets: A deep representation for volumetric shapes. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1912–1920, 2015.
- [27] H. Xiao, K. Rasul, and R. Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.
- [28] M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Poczos, R. R. Salakhutdinov, and A. J. Smola. Deep sets. In Advances in Neural Information Processing Systems, pages 3391–3401, 2017.
- [29] Y. Zhang, J. Hare, and A. Prügel-Bennett. Deep set prediction networks. arXiv preprint arXiv:1906.06565, 2019.
- [30] Z.-H. Zhou, Y.-Y. Sun, and Y.-F. Li. Multi-instance learning by treating instances as non-i.i.d. samples. In Proceedings of the 26th Annual International Conference on Machine Learning, pages 1249–1256. ACM, 2009.