Motivation: While deep learning has achieved great success in computer vision and other fields, it currently does not work well on genomic data because of the ‘‘big $p$, small $N$’’ problem (i.e., a relatively small number of samples with high-dimensional features). To make deep learning work with a small amount of training data, we have to design new models that facilitate few-shot learning. In this paper we focus on developing data-efficient deep learning models that learn from a limited number of training examples and generalize well.
Results: We developed two deep learning modules, a feature attention layer and a k-Nearest-Neighbor (kNN) attention pooling layer, to make our model much more data efficient than conventional deep learning models. The feature attention layer can directly select important features that are useful for patient classification. The kNN attention pooling layer is based on the graph attention model and is well suited to semi-supervised few-shot learning. Experiments on both synthetic data and cancer genomic data from TCGA projects show that our method has better generalization power than a conventional neural network model.
Availability: We have implemented our method using the PyTorch deep learning framework (https://pytorch.org). The code is freely available at https://github.com/BeautyOfWeb/AffinityNet.
AffinityNet: semi-supervised few-shot learning for disease type prediction
Tianle Ma and Aidong Zhang
Patients, drugs, networks, etc., are all complex objects with heterogeneous features or attributes. Complex object clustering and classification are ubiquitous in real world applications. For instance, it is important to cluster cancer patients into subgroups and identify disease subtypes in cancer genomics (Shen et al., 2012; Wang et al., 2014; Ma and Zhang, 2017). Compared with images, which have homogeneous structured features (i.e., pixels are arranged in a 3-D array as raw features), complex objects usually have heterogeneous features with unclear structures. Deep learning models such as Convolutional Neural Networks (CNNs) widely used in computer vision (LeCun et al., 2015; Krizhevsky et al., 2012) and other fields (Bahdanau et al., 2014; Sutskever et al., 2014; Silver et al., 2016; Banino et al., 2018) cannot be directly applied to complex objects whose features are not ordered structurally.
One critical challenge in the cancer patient clustering problem is the ‘‘big $p$, small $N$’’ problem: we only have a small number of samples compared with the high-dimensional features each sample has. In other words, we do not have an ‘‘ImageNet’’ (Russakovsky et al., 2015) to train deep learning models that can learn good representations from raw features. Moreover, unlike pixels in images, patient features such as gene expressions are much noisier and more heterogeneous. These features are not ‘‘naturally’’ ordered. Thus we cannot directly use convolutional neural networks with small filters to extract abstract local features.
Though the number of features (i.e., variables or covariates) in genomic data is usually very high, many features may be irrelevant to a specific task. For instance, a disease may only have a few risk factors involving a small number of features. In order to facilitate feature selection in a ‘‘deep learning’’ way, we propose a feature attention module that can be incorporated into a neural network model and directly learn feature weights using backpropagation.
For a clustering/classification task, nodes belonging to the same cluster should have similar representations that are near the cluster centroid. Based on this intuition we develop k-nearest-neighbor (kNN) attention pooling layer, which is based on Graph Attention Model (GAM) (Veličković et al., 2017) and can facilitate semi-supervised learning with only a few labeled training examples. Both feature attention layer and kNN attention pooling layer can be plugged into a neural network model just as convolutional layers or pooling layers in deep learning models.
Similar to the graph attention model (GAM) (Veličković et al., 2017), we propose a model called AffinityNet, which consists of stacked feature attention and kNN attention pooling layers, to facilitate semi-supervised few-shot learning (i.e., training a model with few labeled examples). Since GAM is designed for representation learning on graphs (Hamilton et al., 2017b), it does not apply to data without a known graph (i.e., without known links among nodes). AffinityNet generalizes GAM to facilitate representation learning on any collection of objects, with or without a known graph.
We performed experiments on both synthetic and real cancer genomics data. The results demonstrated that AffinityNet has better generalization power than conventional neural network model.
1.1 Related work
Feature attention layer can be seen as a special case of metric learning (Bellet et al., 2013). It is very simple and has largely been overlooked by deep learning community. In practice, we found it is useful to incorporate feature attention layer into neural network model and make the model generalize better when there are a large number of input features which are irrelevant or noisy.
kNN attention pooling layer is related to graph learning (Hamilton et al., 2017b; Kipf and Welling, 2016; Veličković et al., 2017), attention model (Vaswani et al., 2017; Veličković et al., 2017), pooling and normalization layers (Ioffe and Szegedy, 2015) in deep learning literature.
In graph learning, a graph has a number of nodes and edges (both nodes and edges can have features). When available, combining node features with graph structure can do a better job than using node features alone. For example, the Graph Convolutional Neural Network (Kipf and Welling, 2016) incorporates graph structure (i.e., edges) into the learning process to facilitate semi-supervised few-shot learning. The graph attention model (GAM) (Veličković et al., 2017) learns a representation for each node based on the weighted pooling (i.e., attention) of its neighborhood in the given graph, and then performs classification using the learned representation. However, all these graph learning algorithms require that a graph (i.e., edges between nodes) is known. Many algorithms also require the input to be the whole graph, and thus do not scale well to large graphs. Our proposed AffinityNet model generalizes graph learning to collections of objects (e.g., patients) without known graphs.
The key component of AffinityNet is the kNN attention pooling layer, which is also related to normalization layers in deep learning, such as batch normalization (Ioffe and Szegedy, 2015), instance normalization (Jing et al., 2017), or layer normalization (Ba et al., 2016). All these normalization layers use batch statistics or feature statistics to normalize instance features, while kNN attention pooling uses an attention mechanism to ‘‘normalize’’ instance features during training.
kNN attention pooling is different from existing max or average pooling layers used in deep learning models, where features in a local neighborhood are pooled to extract the signal and reduce feature dimensions. Our proposed kNN attention pooling layer applies pooling on nodes instead of features. kNN attention pooling layer combines normalization, attention and pooling, making it more general and powerful. It can be seen as an implicit regularizer to make the network generalize well for semi-supervised few-shot learning tasks.
2 Affinity Network Model
One key ingredient for the success of deep learning is its ability to learn a good representation (Bengio et al., 2013) through multiple complex nonlinear transformations. For classification tasks, the learned representation (usually the last hidden layer) is often linearly separable for different classes. If the output layer is a fully connected layer followed by softmax nonlinearity for classification, then the weight matrix for the last layer can be seen as class centroids in the transformed feature space. While conventional deep learning models often perform well when lots of training data is available, our goal is to design new models that learn a good feature transformation in a transparent and data efficient way.
Building upon the existing modules in the deep learning toolbox, we propose feature attention and kNN (node) attention pooling layers, and use them to construct the Affinity Network Model (‘‘AffinityNet’’). In a typical AffinityNet model, as shown in Figure 1, the input layer is followed by a feature attention layer and then by possibly multiple stacked kNN attention pooling layers (Figure 1 only illustrates one kNN attention pooling layer). The final output of the kNN attention pooling layers is the newly learned network representation, which can be used for classification or regression tasks (for example, a Cox model (Mobadersany et al., 2018)).
If we have a few labeled examples, we can use the last layer as the classifier (as shown in Figure 1) to train the model. Though it is possible to train AffinityNet with only labeled examples, it is more advantageous to use it as a semi-supervised learning framework (i.e., using both labeled and unlabeled data during training).
As the major components of AffinityNet are stacked feature attention and kNN attention pooling layers, we describe them in detail in the following sections.
2.1 Feature Attention Layer
Deep neural networks can learn good hierarchical local feature extractors (such as convolutional filters or inception modules (Szegedy et al., 2017)) automatically through backpropagation. Local feature operations such as convolution require features to be ordered structurally. For images or videos, pixels near each other naturally form a neighborhood. However, in other applications, features are not ordered and the structural relations among features are unknown (which is the focus of this paper). Thus we cannot directly learn a local feature extractor; instead, we have to learn a feature selector that can select important features.
In many applications, there are redundant, noisy, or useless features, and the Euclidean distance between objects using all features may be dominated by irrelevant features (Bellet et al., 2013). However, with proper feature weighting, we can separate objects from different classes well. This is the intuition for feature attention layer, which is a special case of metric learning (Bellet et al., 2013).
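Let $h_i \in \mathbb{R}^p$ be the feature vector of object $i$, and let $w \in \mathbb{R}^p$ be the feature attention, satisfying $\sum_{k=1}^{p} w_k = 1$ and $w_k \ge 0$. The feature attention layer rescales each feature by its weight, $\hat{h}_i = w \odot h_i$, where $\odot$ denotes element-wise multiplication.
Before transformation, the squared distance between objects $i$ and $j$ is $\|h_i - h_j\|^2 = \sum_{k=1}^{p} (h_{ik} - h_{jk})^2$, which can be skewed by noisy and useless features. After transformation, the distance $\|\hat{h}_i - \hat{h}_j\|^2 = \sum_{k=1}^{p} w_k^2 (h_{ik} - h_{jk})^2$ can be more informative for classification tasks.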
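The effect of feature weighting on distances can be illustrated with a minimal sketch in plain Python (rather than PyTorch, to keep it dependency-free). The softmax parameterization of the weights, the toy vectors, and the function names are assumptions made for illustration; the text only requires the weights to be non-negative and to sum to one.

```python
import math

def feature_attention(raw_scores):
    """Map unconstrained scores to non-negative weights that sum to 1 (softmax)."""
    m = max(raw_scores)
    exps = [math.exp(s - m) for s in raw_scores]
    total = sum(exps)
    return [e / total for e in exps]

def apply_attention(w, h):
    """Element-wise reweighting of one feature vector."""
    return [wk * hk for wk, hk in zip(w, h)]

def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((ak - bk) ** 2 for ak, bk in zip(a, b))

# Toy objects from the same class: feature 0 carries signal, features 1-2 are noise.
h_i = [1.0, 5.0, -3.0]
h_j = [1.1, -4.0, 6.0]

w = feature_attention([4.0, 0.0, 0.0])  # hypothetical learned scores favoring feature 0
plain = sq_dist(h_i, h_j)
weighted = sq_dist(apply_attention(w, h_i), apply_attention(w, h_j))
print(plain, weighted)  # the weighted distance is far smaller than the plain one
```

With most of the weight on the informative feature, the two same-class objects end up close together even though the noisy features differ widely.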
2.1.1 Difference from fully connected layer
One might wonder if the feature attention layer is useful, as a fully connected layer also has weights that can perform feature selection. Just as skip connections in ResNet (He et al., 2016) help gradient flow, the feature attention layer helps select important features much more easily than a fully connected layer, and can increase the generalization power of a neural network model. Experiments on synthetic data (Sec. 3.1.1) clearly demonstrate the power of the feature attention layer for selecting important features.
In addition, for fully connected layer without weight constraints, the weights can be negative and unbounded. Even if we set non-negativity constraint to weight matrix, the transformed features are linear combinations of input features. We cannot directly determine the importance of input features individually.
The feature attention layer has only the parameter vector $w$ (one weight per input feature), which enforces an implicit constraint of learning a weighted Euclidean metric during training. This can prevent overfitting and increase generalization power. While a fully connected layer has more representation power, it can easily overfit the training data and get stuck in a local minimum.
2.2 kNN attention pooling layer
A good classification model should have the ability to learn a feature transformation such that objects belonging to the same class have similar feature representations and are near the class centroid in the transformed feature space.
As an object’s nearest neighbors should have similar feature representations, we propose the kNN attention pooling layer to incorporate neighborhood information using an attention mechanism (Eq. 2.2):
\[ h'_i = \sum_{j \in \mathcal{N}(i)} a(h_i, h_j) \, f(h_j) \]
In Eq. 2.2, $h_i$ and $h'_i$ are the input and transformed feature representations of object $i$, respectively. $\mathcal{N}(i)$ represents the neighborhood of object $i$. If a graph is given, as in the graph learning setting (Hamilton et al., 2017b), we can use the given graph to determine the neighborhood. However, in order to control computational complexity, we can limit the neighborhood size to a fixed small number $k$ even if a graph is given (Hamilton et al., 2017a). In a kNN attention pooling layer, $k$ is a hyperparameter that determines how many neighbors are used for calculating the representation of a node.
$f(\cdot)$ is a nonlinear transformation (e.g., a fully connected layer followed by a nonlinear activation). $a(h_i, h_j)$ is the attention from object $i$ to object $j$, and $a(\cdot, \cdot)$ is the attention kernel that will be discussed in the next section.
2.3 Attention kernels
Intuitively, if two objects are similar, their feature representations should be near each other. Objects belonging to the same class should be clustered together in the learned feature space. In order to achieve this, kNN attention pooling layer uses weighted pooling to ‘‘attract’’ similar objects together in the transformed feature space. Attention kernels essentially calculate such similarities among objects to facilitate weighted pooling.
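There are many choices of attention kernels. Examples include cosine similarity, $a(h_i, h_j) = \frac{h_i \cdot h_j}{\|h_i\| \, \|h_j\|}$, which we use in our experiments (Sec. 3.2), and the inner product, $a(h_i, h_j) = h_i \cdot h_j$.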
In order to calculate a weighted average of new representations, we can use the softmax function to normalize the attention (other normalizations are also feasible). So the normalized attention kernel is (Eq. 2.3):
\[ \tilde{a}(h_i, h_j) = \frac{\exp\big(a(h_i, h_j)\big)}{\sum_{j' \in \mathcal{N}(i)} \exp\big(a(h_i, h_{j'})\big)} \]
Now $\sum_{j \in \mathcal{N}(i)} \tilde{a}(h_i, h_j) = 1$. For each node $i$, we only use its neighbors in $\mathcal{N}(i)$ for normalization. If the graph is not given, we can use an attention kernel to calculate attentions among all nodes (i.e., an affinity graph), and then select the top $k$ nodes as its $k$ nearest neighbors. We can use different attention kernels for calculating the affinity graph that determines the nearest neighborhood and for calculating the normalized attention, similar to the key-value based attention mechanism (Vaswani et al., 2017).
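The pooling step can be sketched in a few lines of plain Python. This is an illustrative simplification, not the released PyTorch implementation: it uses cosine similarity both for selecting neighbors and for the softmax-normalized attention, treats each node as its own nearest neighbor, and takes the nonlinear transformation to be the identity. The function names and toy data are hypothetical.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors (small epsilon avoids division by zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + 1e-12)

def knn_attention_pool(H, k):
    """Replace each row of H by the attention-weighted average of its k most similar rows."""
    pooled = []
    for h_i in H:
        scores = [cosine(h_i, h_j) for h_j in H]
        # k nearest neighbors by kernel value (the node itself is included)
        nbrs = sorted(range(len(H)), key=lambda j: scores[j], reverse=True)[:k]
        # softmax over the selected neighbors only
        m = max(scores[j] for j in nbrs)
        exps = [math.exp(scores[j] - m) for j in nbrs]
        z = sum(exps)
        att = [e / z for e in exps]
        pooled.append([sum(a * H[j][d] for a, j in zip(att, nbrs))
                       for d in range(len(h_i))])
    return pooled

# Two loose groups of points: rows 0-1 point along the first axis, rows 2-3 along the second.
H = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
pooled = knn_attention_pool(H, k=2)
print(pooled)
```

Each pooled row is a convex combination of rows from its own group, so points within a group are pulled toward each other while the two groups stay apart.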
2.3.1 Layer-specific dynamic affinity graph
The kNN attention pooling layer can be applied to a collection of objects regardless of whether a graph (e.g., edges between objects) is given. If a graph is given, we can directly use it to determine the neighborhood in Eq. 2.2 and Eq. 2.3, which is the same as in the Graph Attention Model (Veličković et al., 2017). If the degree of the graph is too high and some nodes have very large neighborhoods, then we can select only the $k$ nearest neighbors for calculating attention when computational cost is a big concern.
Regardless of whether the graph is given, we can always calculate an affinity graph based on node features using some similarity metric, including the attention kernels discussed in Sec. 2.3. Our AffinityNet model (Figure 1) may contain multiple kNN pooling layers stacked together. We can thus calculate an affinity graph using the learned node feature representations from each layer. Graphs calculated using higher-level features may be more informative for separating different classes.
Besides, we can use the graph calculated using features from the previous layer to determine the k-nearest-neighborhood for the next layer. This can be seen as an implicit regularizer preventing the learned representation from drifting away from previous layer too much in a single layer operation.
Mathematically, for layer $l$, we can calculate a layer-specific dynamic affinity graph $A^{(l)}$ using Eq. 2.3.1:
\[ A^{(l)} = \lambda_1 A + \lambda_2 \hat{A}^{(l)} + \lambda_3 \hat{A}^{(l-1)} \]
In Eq. 2.3.1, $A$ is the given graph if available. When it is not available, we can simply set $A = 0$. $\hat{A}^{(l)}$ is the node-feature-derived graph for layer $l$, and $\lambda_1$, $\lambda_2$, $\lambda_3$ are non-negative mixing weights. There is a good reason to include $\hat{A}^{(l-1)}$ in Eq. 2.3.1: it is a ‘‘natural regularizer’’ that makes the current layer incorporate information from the previous layer and prevents the current layer from drifting away from the previous layer too rapidly. If we do not use the node-feature-derived graph from the previous layer, we can simply set $\lambda_3 = 0$ in Eq. 2.3.1.
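To make the regularizing role of the previous layer's graph concrete, the sketch below combines a given graph, the current layer's feature-derived graph, and the previous layer's feature-derived graph as a weighted element-wise sum, then reads off each node's nearest neighbors. The weighted-sum form, the mixing weights `lam`, and all names are assumptions for illustration.

```python
def combine_graphs(given, current, previous, lam=(1.0, 1.0, 1.0)):
    """Weighted element-wise sum of three N x N affinity matrices.

    Any matrix may be None, in which case it is treated as all zeros.
    """
    mats = (given, current, previous)
    n = len(next(m for m in mats if m is not None))
    out = [[0.0] * n for _ in range(n)]
    for weight, mat in zip(lam, mats):
        if mat is None:
            continue
        for i in range(n):
            for j in range(n):
                out[i][j] += weight * mat[i][j]
    return out

def top_k_neighbors(affinity, k):
    """Indices of the k largest affinities in each row."""
    return [sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
            for row in affinity]

# Hypothetical affinities: the current layer links node 0 to node 2,
# while the previous layer linked node 0 to node 1.
current = [[1.0, 0.2, 0.9], [0.2, 1.0, 0.1], [0.9, 0.1, 1.0]]
previous = [[1.0, 0.8, 0.1], [0.8, 1.0, 0.1], [0.1, 0.1, 1.0]]

weak = top_k_neighbors(combine_graphs(None, current, previous, lam=(0.0, 1.0, 0.5)), k=2)
strong = top_k_neighbors(combine_graphs(None, current, previous, lam=(0.0, 1.0, 2.0)), k=2)
print(weak[0], strong[0])  # [0, 2] [0, 1]
```

With a small weight on the previous layer, node 0 follows the current features; with a large weight, its neighborhood stays the one found in the previous layer.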
If the input of the AffinityNet model consists of $N$ objects, then we are essentially learning an affinity graph for these $N$ objects. In this sense, we also call our framework affinity network learning.
2.3.2 Semi-supervised few-shot learning
Semi-supervised few-shot learning (Ravi and Larochelle, 2017; Kingma et al., 2014; Kipf and Welling, 2016; Rasmus et al., 2015) allows only very few labeled instances for training and requires the model to generalize well. Our proposed AffinityNet model, consisting of feature attention and kNN attention pooling layers, performs well in this setting.
For cancer patient clustering problems, we usually have several hundred patients in a study. If we can obtain a few labeled training examples (e.g., human experts can manually assign labels to some patients), we can use the AffinityNet model for semi-supervised learning. The input of the model is the patient-feature matrix consisting of all patients, but we only backpropagate the classification error for the labeled patients. Unlike a conventional neural network, where each instance is trained independently, AffinityNet can utilize unlabeled instances for calculating kNN attention-based representations in the whole sample pool. In a sense, the kNN attention pooling layer performs both nonlinear transformation and ‘‘clustering’’ (attracting similar instances together in the learned feature space) during training. Even though the labels of most patients are unknown, their feature representations can be used for learning a global affinity graph, which is helpful for clustering or classifying all patients in the cohort.
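The idea of backpropagating the error only for labeled patients amounts to masking the loss. A minimal sketch in plain Python follows; the function name and the use of `None` to mark unlabeled rows are illustrative choices, and in the actual model the unlabeled rows would still influence the logits of labeled rows through kNN attention pooling.

```python
import math

def masked_cross_entropy(logits, labels):
    """Mean cross-entropy over labeled rows only; labels[i] is None for unlabeled rows."""
    total, count = 0.0, 0
    for row, y in zip(logits, labels):
        if y is None:
            continue  # unlabeled: no direct contribution to the loss
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        total += log_z - row[y]
        count += 1
    return total / count

# Three patients, two classes; only the first patient is labeled.
logits = [[2.0, 0.0], [0.0, 0.0], [0.0, 3.0]]
labels = [0, None, None]
loss = masked_cross_entropy(logits, labels)

# Changing the outputs for unlabeled patients leaves the loss unchanged.
other = masked_cross_entropy([[2.0, 0.0], [5.0, -5.0], [1.0, 1.0]], labels)
print(loss, other)
```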
When dealing with very large graphs, we can input a small batch consisting of a partial graph to AffinityNet model to reduce possible computational burden. Though each batch may contain different instances, kNN pooling layer can still work well with attention mechanism. Our PyTorch implementation of AffinityNet can even handle the extreme case where only one instance is fed into the model at a time, in which case the AffinityNet model operates as conventional deep learning model to learn a nonlinear transformation only without pooling operation.
3.1.1 The power of feature attention layer
We sampled 4000 points from four independent 2-dimensional Gaussian distributions with different means, one distribution per cluster; Figure 2 shows the points colored by their true cluster assignments. We then appended 40 dimensions of Gaussian noise to each point. Thus each point has 42 dimensions, with the first two containing true signal and the rest being random noise.
We constructed two models to predict class labels: ‘‘Neural Net’’: a neural network model with an input layer (42-dimensional), a hidden layer (100-dimensional) and an output layer (4-dimensional); ‘‘Affinity Net’’: the same as ‘‘Neural Net’’ except that one feature attention layer with kNN attention pooling is added after the input layer.
We randomly selected 1% of the data (40 out of 4000 points) to train the two models and compared accuracies on the test set. Surprisingly, when trained on only 1% of the data, our model with the feature attention layer successfully selects the true signal features and achieves 98.2% accuracy on the test set. By contrast, the plain neural network model achieves only 46.9% accuracy on the test set.
In Figure 3, the upper panel plots the loss and accuracy curves during training for ‘‘Affinity Net’’ (top left) and ‘‘Neural Net’’ (top right). Even though both models achieve 100% training accuracy within a few iterations, ‘‘Affinity Net’’ generalizes better than the plain neural network model. There is a big gap between the training and test accuracy curves for the ‘‘Neural Net’’ model when the training set is small.
Strikingly, the good generalization of our model relies on the success of the feature attention layer in picking up the true signals from the noise. The lower panel of Figure 3 shows the learned weights for the 42-dimensional input features, with the red dots corresponding to true signals and the blue dots to noise. The weights of the true signals learned by ‘‘Affinity Net’’ are much higher than those of the noise features, while the plain ‘‘Neural Net’’ does not select the true signals very well.
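A data set of this shape can be generated with the standard library alone. In the sketch below the cluster means, the standard deviations, and the seeds are hypothetical placeholders chosen for illustration; only the sizes (4000 points, 2 signal plus 40 noise dimensions, 1% used for training) follow the description.

```python
import random

def make_synthetic(n_per_cluster=1000, n_noise=40, seed=0):
    """Four Gaussian clusters in 2-D, padded with pure-noise dimensions."""
    rng = random.Random(seed)
    centers = [(-1.0, -1.0), (-1.0, 1.0), (1.0, -1.0), (1.0, 1.0)]  # hypothetical means
    X, y = [], []
    for label, (cx, cy) in enumerate(centers):
        for _ in range(n_per_cluster):
            signal = [rng.gauss(cx, 0.5), rng.gauss(cy, 0.5)]  # assumed std of 0.5
            noise = [rng.gauss(0.0, 1.0) for _ in range(n_noise)]
            X.append(signal + noise)
            y.append(label)
    return X, y

X, y = make_synthetic()
# Randomly pick 1% of the points as the labeled training set.
train_idx = random.Random(1).sample(range(len(X)), k=len(X) // 100)
print(len(X), len(X[0]), len(train_idx))  # 4000 42 40
```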
3.2 Cancer patient classification
Harmonized kidney and uterus cancer datasets were downloaded from Genomic Data Commons Data Portal (https://portal.gdc.cancer.gov) (Grossman et al., 2016). Kidney cancer has three disease types, and uterus cancer has two. We are trying to classify each tumor sample into its disease types for uterus and kidney cancer separately.
The number of samples from each disease type is summarized in Table 1. Both kidney cancer and uterus cancer have unbalanced classes (i.e., one class has far fewer samples than the others).
| Cancer type | Disease type | Samples | Total |
| Kidney | Renal Clear Cell Carcinoma | 316 | |
| Kidney | Renal Papillary Cell Carcinoma | 273 | |
| Uterus | Uterine Corpus Endometrial Carcinoma | 421 | 475 |
| Uterus | Uterine Carcinosarcoma | 54 | |
We calculated the standard deviation of gene expression values for each gene across samples within a cancer type (i.e., kidney or uterus), and selected the top 1000 most variant gene expression features as input to our model.
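This variance-based filter is simple to express in code. The sketch below uses the population standard deviation from Python's `statistics` module on a tiny hypothetical matrix (rows are samples, columns are genes); the function name is illustrative, and whether the population or the sample standard deviation is used does not change the ranking.

```python
from statistics import pstdev

def top_variable_features(X, n_top):
    """Return the column indices of the n_top features with the largest standard deviation."""
    n_features = len(X[0])
    sds = [pstdev([row[j] for row in X]) for j in range(n_features)]
    order = sorted(range(n_features), key=lambda j: sds[j], reverse=True)
    return sorted(order[:n_top])  # keep the original column order

# Toy expression matrix: gene 0 is constant, gene 1 varies strongly, gene 2 slightly.
X_expr = [[1.0, 10.0, 5.0],
          [1.0, 20.0, 5.1],
          [1.0, 30.0, 4.9]]
selected = top_variable_features(X_expr, n_top=2)
print(selected)  # [1, 2]
```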
We compare our model (‘‘Affinity Net’’) with five other methods: ‘‘Neural Net’’ (conventional deep learning model), ‘‘SVM’’, ‘‘Naive Bayes’’, ‘‘Random Forest’’, and ‘‘Nearest Neighbors’’ (kNN).
Our model (‘‘Affinity Net’’) consists of a feature attention layer, a kNN attention pooling layer (100 hidden units), and a fully connected layer with a softmax classifier. For the kNN attention pooling layer, we use the ‘‘cosine similarity’’ kernel and set the number of nearest neighbors $k$ separately for kidney cancer and uterus cancer.
A conventional two-layer fully connected deep learning model (‘‘Neural Net’’) with one hidden layer (100 hidden units) and an output layer was constructed for comparison with ‘‘Affinity Net’’. For both ‘‘Affinity Net’’ and ‘‘Neural Net’’, we use the ReLU nonlinear activation in the hidden layer.
Since the input dimension is 1000 (i.e., the top 1000 most variant gene expressions), ‘‘Neural Net’’ has 100,403 parameters for kidney cancer with three classes (i.e., disease types), and 100,302 parameters for uterus cancer with two classes (counting weights and biases: $1000 \times 100 + 100$ for the hidden layer plus $100 \times c + c$ for an output layer with $c$ classes). Our model ‘‘Affinity Net’’ has 101,403 and 101,302 parameters for kidney and uterus cancer, respectively. Our model thus has only 1000 more parameters than ‘‘Neural Net’’, which keeps the comparison fair. We do not use more layers in our neural network models because there are only several hundred samples to train on, and larger models are more likely to overfit.
We used the implementations from scikit-learn (http://scikit-learn.org) for ‘‘Naive Bayes’’, ‘‘SVM’’, ‘‘Nearest Neighbors’’, and ‘‘Random Forest’’ with default settings.
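The parameter counts can be checked with a few lines of arithmetic. The sketch assumes that every fully connected layer carries a bias term, that the feature attention layer contributes exactly one weight per input feature, and that the attention pooling step itself adds no further parameters; the function names are illustrative.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

def neural_net_params(n_features, n_hidden, n_classes):
    """Input -> hidden -> output, both layers fully connected."""
    return linear_params(n_features, n_hidden) + linear_params(n_hidden, n_classes)

def affinity_net_params(n_features, n_hidden, n_classes):
    """Same as above plus one attention weight per input feature."""
    return n_features + neural_net_params(n_features, n_hidden, n_classes)

# Kidney cancer: 1000 input features, 100 hidden units, three classes.
print(neural_net_params(1000, 100, 3))    # 100403
print(affinity_net_params(1000, 100, 3))  # 101403
# Uterus cancer: same architecture with two classes.
print(neural_net_params(1000, 100, 2))
print(affinity_net_params(1000, 100, 2))
```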
We progressively increase the training portion from 1% to 70% (i.e., 1%, 10%, 20%, 30%, 40%, 50%, 60% and 70%), and report adjusted mutual information (AMI) on the test set (Table 2 and Table 3). AMI is an adjustment of the Mutual Information (MI) score to account for chance, which is suitable for measuring the performance of clustering and classification with multiple unbalanced classes (e.g., one class may have many more samples than the others).
We run experiments 20 times with different random seeds to generate different training and test sets. For each run, the training and test sets are identical for all six methods. We report the mean AMI scores of the top 10 runs (results depend on the few selected training examples and other sources of randomness) for all methods in Table 2 and Table 3, which show adjusted mutual information scores for kidney and uterus cancer, respectively.
For both cancer types, our model clearly outperforms all other models, especially when the training portion is small. For example, when trained on only 1% of the data, our model achieves AMI = 0.84 for kidney cancer and AMI = 0.62 for uterus cancer (Table 2 and Table 3), while the other methods perform poorly with so few training examples. This suggests our model is highly data efficient, in the sense that it requires less labeled training data to generalize well. One major reason is that the kNN attention pooling layer in a sense performs ‘‘clustering’’ during training, and is therefore less likely to overfit to a small number of labeled training examples. The input of the kNN attention pooling layer can contain not only labeled training examples but also unlabeled examples. It performs semi-supervised learning with a few labeled examples as guidance for finding ‘‘clusters’’ among all data points. ‘‘Neural Net’’ and the other methods do not perform well with few labeled training examples because they tend to overfit the training set. As more training data becomes available, the other methods, including ‘‘Neural Net’’, improve rapidly. In this case, ‘‘Neural Net’’ does not outperform traditional machine learning techniques such as ‘‘SVM’’ because the dataset is quite small, and the power of deep learning is only manifested when large amounts of data are available.
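For reference, AMI between a true labeling $U$ and a predicted labeling $V$ is commonly written as below, where $H(\cdot)$ is entropy and $\mathbb{E}[\mathrm{MI}(U,V)]$ is the expected mutual information of random labelings with the same class sizes. The choice of normalizer in the denominator (here the arithmetic mean of the two entropies; the maximum is another common choice) is an assumption, since it depends on the implementation used.

```latex
\[
\mathrm{AMI}(U, V) =
\frac{\mathrm{MI}(U, V) - \mathbb{E}[\mathrm{MI}(U, V)]}
     {\tfrac{1}{2}\big(H(U) + H(V)\big) - \mathbb{E}[\mathrm{MI}(U, V)]}
\]
```

Under this definition a perfect prediction gives AMI = 1, while a prediction no better than chance gives a value near 0 regardless of how unbalanced the classes are.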
In Table 2, note that for kidney cancer, unlike other methods, our model does not improve with more training data, partly because there are a few very hard cases in kidney cancer dataset, while all other cases are almost linearly separable. Our model can easily pick up the linearly separable clusters with only a few training examples, but it is hard to separate very hard cases even when more training data is available.
The uterus cancer dataset is highly unbalanced, with one class being much smaller than the other, and thus it is much harder to achieve high adjusted mutual information (AMI). As shown in Table 3, ‘‘Affinity Net’’ achieves AMI = 0.62 when trained on approximately 1% of the data (i.e., one randomly chosen sample from Uterine Carcinosarcoma and four samples from the other disease type), and performs significantly better than the other models even as the training portion increases to 70%. This suggests ‘‘Affinity Net’’ works well on unbalanced data, while the other methods are inadequate.
3.3 Semi-supervised clustering
Cancer patient clustering and disease subtype discovery are very challenging because of small sample sizes and the lack of enough training examples with ground-truth labels. If we can obtain label information for a few samples, we can use ‘‘Affinity Net’’ for semi-supervised clustering (Weston et al., 2012).
Both ‘‘Affinity Net’’ and ‘‘Neural Net’’ can learn a new feature representation through multiple nonlinear transformations. For classification model, the new feature representation is usually fed into a fully connected layer followed by a softmax nonlinearity. We can train our model with a few labeled examples, and then feed all data points with raw features to neural network model and get the transformed feature representations.
For ‘‘Affinity Net’’, if we use all data points during training but backpropagate the error only for the labeled training examples, we obtain the learned new representations for all data points once training is finished. For a conventional neural network, since each data point is trained independently, we only use labeled examples during training. After training, we can then generate new feature representations for all data points using the trained neural network model.
In order to evaluate the quality of learned feature representation with a few training examples, we perform clustering using these transformed features and using the original features, and compare them with groundtruth class labels.
We compare the performance of our proposed ‘‘Affinity Net’’ and ‘‘Neural Net’’ on the kidney dataset. We randomly selected 1% of the data for training, and ran experiments 30 times. After training, we performed spectral clustering on the transformed patient-feature matrix. Figure 4 shows the adjusted mutual information scores for all 30 runs using ‘‘Affinity Net’’ and ‘‘Neural Net’’. We also performed spectral clustering on the original patient-feature matrix as a baseline method (AMI = 0.71, blue dotted line in the figure). Our model outperforms both ‘‘Neural Net’’ (Wilcoxon signed rank test) and the baseline, while ‘‘Neural Net’’ is slightly below the baseline because it overfits the few training examples. This shows that while ‘‘Neural Net’’ and ‘‘Affinity Net’’ have approximately the same number of model parameters, only ‘‘Affinity Net’’ can learn a good feature transformation, by facilitating semi-supervised few-shot learning with the kNN attention pooling layer and the feature attention layer.
3.4 Combine with Cox model for survival analysis
For many cancer genomics studies, cancer subtype information is not known, but patient survival information is available. We replace the last classifier layer in the model (as shown in Figure 1) with a regression layer following the Cox proportional hazards model (Mobadersany et al., 2018; Fox, 2002). We use backpropagation to learn model parameters that maximize the partial likelihood in the Cox model.
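Maximizing the partial likelihood is equivalent to minimizing its negative logarithm, which can serve as the training loss. The plain-Python sketch below computes that quantity for given risk scores (the output of the regression layer); it ignores tie corrections, and the function name and toy data are illustrative rather than taken from the released code.

```python
import math

def neg_log_partial_likelihood(risk, time, event):
    """Negative Cox partial log-likelihood (no correction for tied event times).

    risk[i]  : predicted log-hazard for patient i
    time[i]  : follow-up time
    event[i] : 1 if death was observed, 0 if censored
    """
    total = 0.0
    for i in range(len(risk)):
        if not event[i]:
            continue  # censored patients only appear in risk sets
        at_risk = [risk[j] for j in range(len(risk)) if time[j] >= time[i]]
        m = max(at_risk)
        log_sum = m + math.log(sum(math.exp(r - m) for r in at_risk))
        total += log_sum - risk[i]
    return total

time = [1.0, 2.0, 3.0]
event = [1, 1, 0]
good = neg_log_partial_likelihood([2.0, 1.0, 0.0], time, event)  # high risk dies first
bad = neg_log_partial_likelihood([0.0, 1.0, 2.0], time, event)   # ordering reversed
print(good, bad)  # the correctly ordered risks give the lower loss
```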
We perform experiments on kidney dataset that has more than 600 samples. We progressively increase the training portion from 10% to 40%. We use 30% of data as validation and the remaining as test set. As baseline method, we use age, gender and known disease types as covariates to fit a Cox model. We run experiments 20 times with random seeds, and report concordance index on the test set for both our model and baseline Cox model (Figure 5).
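The concordance index measures how often the predicted risks order patients correctly. A minimal sketch of Harrell's version is given below; it counts a pair as comparable when the patient with the shorter time had an observed event, gives half credit to tied risks, and uses illustrative names and toy data.

```python
def concordance_index(risk, time, event):
    """Fraction of comparable pairs in which the earlier death has the higher risk."""
    concordant, comparable = 0.0, 0
    n = len(risk)
    for i in range(n):
        if not event[i]:
            continue  # a censored patient cannot be the earlier member of a pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable

time = [1.0, 2.0, 3.0]
event = [1, 1, 0]
print(concordance_index([2.0, 1.0, 0.0], time, event))  # 1.0
print(concordance_index([0.0, 1.0, 2.0], time, event))  # 0.0
print(concordance_index([1.0, 1.0, 1.0], time, event))  # 0.5
```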
In Figure 5, the light blue boxplots on the left side correspond to the baseline method (Cox model on age, gender and disease types), and the light green ones correspond to our model. The reported p-value between our model and the baseline method for training on 10% of the data is calculated using the Wilcoxon signed rank test. Our model outperforms the baseline model by a significant margin (Table 4 shows the mean concordance index in different settings).
| Training portion | 10% | 20% | 30% | 40% |
| Baseline Cox Model | 0.601 | 0.602 | 0.616 | 0.618 |
There are three disease types of kidney cancer. We use our best model trained on 10% of the data to calculate the hazard rates for all kidney cancer patients in the dataset, and split them into three groups with low, intermediate, and high hazard rates. The proportions of the three groups are the same as those of the three disease types. Figure 6 shows the Kaplan-Meier plot for both the three known disease types (dotted lines) and the three groups based on predicted hazard rate (AffinityNet-low, AffinityNet-int., and AffinityNet-high in the figure). The log-rank test p-value for our predicted groups is smaller than that for the three known disease types, indicating that our model can better separate patients with different survival times.
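Splitting patients into hazard groups of prescribed sizes can be sketched as follows. The group sizes and hazard values below are toy numbers, and the function name is hypothetical; in the experiment the sizes would be taken from the disease-type counts.

```python
def split_by_hazard(hazard, group_sizes):
    """Assign each patient a group index; groups are filled from lowest to highest hazard."""
    order = sorted(range(len(hazard)), key=lambda i: hazard[i])
    groups = [None] * len(hazard)
    start = 0
    for g, size in enumerate(group_sizes):
        for i in order[start:start + size]:
            groups[i] = g
        start += size
    return groups

# Six patients split into low (2), intermediate (3) and high (1) hazard groups.
hazard = [0.9, 0.1, 0.5, 0.3, 0.7, 0.2]
groups = split_by_hazard(hazard, group_sizes=[2, 3, 1])
print(groups)  # [2, 0, 1, 1, 1, 0]
```

The resulting group labels can then be passed to any Kaplan-Meier and log-rank routine in place of the known disease types.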
4 Discussion and Conclusion
Deep learning has achieved great success in computer vision, natural language processing, and speech recognition, where features are well structured (pixels, words, and audio signals are well ordered) and a large amount of training data is available. However, in biomedical research, the training sample size is usually small while the feature dimension is very high, where deep learning models tend to overfit the training data but fail to generalize. To alleviate this problem in the patient clustering related tasks, we propose AffinityNet model that contains feature attention and kNN attention pooling layers to facilitate few-shot learning.
The feature attention layer can be seen as a special case of metric learning. This simple layer is especially helpful when most input features are irrelevant to the specific task or noisy. It can be incorporated into a deep learning model and select important features automatically, with a normalized non-negative weight learned for each feature. The kNN attention pooling layer generalizes the graph attention model to cases where no graph information is available. Without known graph information, the kNN attention pooling layer can use attention kernels to calculate an affinity graph. kNN attention pooling layers essentially add a ‘‘clustering’’ operation (‘‘forcing’’ similar objects to have similar representations through attention-based pooling) after the nonlinear feature transformation, which can be seen as an implicit regularizer for classification-related tasks. Both the feature attention layer and the kNN attention pooling layer can be plugged into a deep learning model as basic building blocks.
Building from feature attention and kNN pooling layer, AffinityNet model is more data efficient and is very effective for semi-supervised few-shot learning compared with conventional deep learning model. We have conducted extensive experiments using AffinityNet on two cancer genomics datasets and achieved good results.
Though we did not discuss multi-view data in this paper due to page limit, our PyTorch implementation of AffinityNet did take multi-view data into consideration and included view fusion layer for multi-view learning (Wang et al., 2015). We hypothesize it is better to process each view separately to learn a high-level representation for each view and then combine them with high-level feature representations.
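The hypothesis above, that each view should first be encoded separately and only then combined, can be sketched as a simple late-fusion step. This is an assumed design for illustration only and is not a description of the view fusion layer in the released code: the function `fuse_views`, the per-view linear projections followed by `tanh`, and the softmax-weighted sum over views are all hypothetical choices.

```python
import numpy as np


def fuse_views(views, projections, view_logits):
    """Encode each view separately, then combine the encodings.

    views:       list of (n_samples, d_v) arrays, one per view.
    projections: list of (d_v, d_out) arrays acting as per-view encoders
                 (hypothetical; a real model would learn them).
    view_logits: (n_views,) parameters turned into view weights by softmax.
    """
    if not (len(views) == len(projections) == len(view_logits)):
        raise ValueError("need one projection and one logit per view")

    # Normalized, non-negative weight for every view.
    shifted = np.asarray(view_logits, dtype=float)
    shifted = shifted - shifted.max()
    weights = np.exp(shifted) / np.exp(shifted).sum()

    # High-level representation of each view, all with the same width d_out.
    encoded = [np.tanh(v @ p) for v, p in zip(views, projections)]

    # Weighted sum of the per-view representations.
    fused = sum(w * z for w, z in zip(weights, encoded))
    return fused, weights


# Two hypothetical views of 5 patients with 4 and 6 features.
rng = np.random.default_rng(0)
views = [rng.normal(size=(5, 4)), rng.normal(size=(5, 6))]
projections = [rng.normal(size=(4, 3)), rng.normal(size=(6, 3))]
fused, view_weights = fuse_views(views, projections, np.zeros(2))
```

Because the views are combined only after encoding, views with very different dimensionality or scale do not have to share an input layer.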
AffinityNet alleviates the problem of lack of a sufficient amount of labeled training data by utilizing unlabeled data with kNN attention pooling, and can be used to analyze large bulk of cancer genomics data for patient clustering and disease subtype discovery. Future work may focus on designing deep learning modules that can incorporate biological knowledge for various tasks.
This work was supported in part by the US National Science Foundation under grants NSF IIS-1218393 and IIS-1514204. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation.
- Ba et al. (2016) Ba, J. L., Kiros, J. R., and Hinton, G. E. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
- Bahdanau et al. (2014) Bahdanau, D., Cho, K., and Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
- Banino et al. (2018) Banino, A., Barry, C., Uria, B., Blundell, C., Lillicrap, T., Mirowski, P., Pritzel, A., Chadwick, M. J., Degris, T., Modayil, J., et al. (2018). Vector-based navigation using grid-like representations in artificial agents. Nature.
- Bellet et al. (2013) Bellet, A., Habrard, A., and Sebban, M. (2013). A survey on metric learning for feature vectors and structured data. arXiv preprint arXiv:1306.6709.
- Bengio et al. (2013) Bengio, Y., Courville, A., and Vincent, P. (2013). Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(8), 1798–1828.
- Fox (2002) Fox, J. (2002). Cox proportional-hazards regression for survival data. An R and S-PLUS companion to applied regression, 2002.
- Grossman et al. (2016) Grossman, R. L., Heath, A. P., Ferretti, V., Varmus, H. E., Lowy, D. R., Kibbe, W. A., and Staudt, L. M. (2016). Toward a shared vision for cancer genomic data. New England Journal of Medicine, 375(12), 1109–1112.
- Hamilton et al. (2017a) Hamilton, W., Ying, Z., and Leskovec, J. (2017a). Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pages 1025–1035.
- Hamilton et al. (2017b) Hamilton, W. L., Ying, R., and Leskovec, J. (2017b). Representation learning on graphs: Methods and applications. arXiv preprint arXiv:1709.05584.
- He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778.
- Ioffe and Szegedy (2015) Ioffe, S. and Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pages 448–456.
- Jing et al. (2017) Jing, Y., Yang, Y., Feng, Z., Ye, J., and Song, M. (2017). Neural style transfer: A review. arXiv preprint arXiv:1705.04058.
- Kingma et al. (2014) Kingma, D. P., Mohamed, S., Rezende, D. J., and Welling, M. (2014). Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pages 3581–3589.
- Kipf and Welling (2016) Kipf, T. N. and Welling, M. (2016). Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.
- Krizhevsky et al. (2012) Krizhevsky, A., Sutskever, I., and Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1097–1105.
- LeCun et al. (2015) LeCun, Y., Bengio, Y., and Hinton, G. (2015). Deep learning. Nature, 521(7553), 436.
- Ma and Zhang (2017) Ma, T. and Zhang, A. (2017). Integrate multi-omic data using affinity network fusion (ANF) for cancer patient clustering. arXiv preprint arXiv:1708.07136.
- Mobadersany et al. (2018) Mobadersany, P., Yousefi, S., Amgad, M., Gutman, D. A., Barnholtz-Sloan, J. S., Velázquez Vega, J. E., Brat, D. J., and Cooper, L. A. D. (2018). Predicting cancer outcomes from histology and genomics using convolutional networks. Proceedings of the National Academy of Sciences.
- Rasmus et al. (2015) Rasmus, A., Berglund, M., Honkala, M., Valpola, H., and Raiko, T. (2015). Semi-supervised learning with ladder networks. In Advances in Neural Information Processing Systems, pages 3546–3554.
- Ravi and Larochelle (2017) Ravi, S. and Larochelle, H. (2017). Optimization as a model for few-shot learning. International Conference on Learning Representations (ICLR).
- Russakovsky et al. (2015) Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., Berg, A. C., and Fei-Fei, L. (2015). ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV), 115(3), 211–252.
- Shen et al. (2012) Shen, R., Mo, Q., Schultz, N., Seshan, V. E., Olshen, A. B., Huse, J., Ladanyi, M., and Sander, C. (2012). Integrative subtype discovery in glioblastoma using iCluster. PLoS ONE, 7(4), e35236.
- Silver et al. (2016) Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., et al. (2016). Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587), 484–489.
- Sutskever et al. (2014) Sutskever, I., Vinyals, O., and Le, Q. V. (2014). Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems, pages 3104–3112.
- Szegedy et al. (2017) Szegedy, C., Ioffe, S., Vanhoucke, V., and Alemi, A. A. (2017). Inception-v4, Inception-ResNet and the impact of residual connections on learning. In AAAI, volume 4, page 12.
- Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. (2017). Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000–6010.
- Veličković et al. (2017) Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., and Bengio, Y. (2017). Graph attention networks. arXiv preprint arXiv:1710.10903.
- Wang et al. (2014) Wang, B., Mezlini, A. M., Demir, F., Fiume, M., Tu, Z., Brudno, M., Haibe-Kains, B., and Goldenberg, A. (2014). Similarity network fusion for aggregating data types on a genomic scale. Nature Methods, 11, 333.
- Wang et al. (2015) Wang, W., Arora, R., Livescu, K., and Bilmes, J. (2015). On deep multi-view representation learning. In International Conference on Machine Learning, pages 1083–1092.
- Weston et al. (2012) Weston, J., Ratle, F., Mobahi, H., and Collobert, R. (2012). Deep learning via semi-supervised embedding. In Neural Networks: Tricks of the Trade, pages 639–655. Springer.