ProtoAttend: Attention-Based Prototypical Learning

Sercan Ö. Arık
Google Cloud AI
Sunnyvale, CA
soarik@google.com
Tomas Pfister
Google Cloud AI
Sunnyvale, CA
tpfister@google.com
Abstract

We propose a novel inherently interpretable machine learning method that bases decisions on few relevant examples that we call prototypes. Our method, ProtoAttend, can be integrated into a wide range of neural network architectures, including pre-trained models. It utilizes an attention mechanism that relates the encoded representations to samples in order to determine prototypes. The resulting model outperforms the state of the art in three high-impact problems without sacrificing accuracy of the original model: (1) it enables high-quality interpretability that outputs samples most relevant to the decision-making (i.e. a sample-based interpretability method); (2) it achieves state-of-the-art confidence estimation by quantifying the mismatch across prototype labels; and (3) it obtains state-of-the-art performance in distribution mismatch detection. All this can be achieved with minimal additional test-time cost and a practically viable training-time computational cost.

1 Introduction

Deep neural networks have been pushing the frontiers of artificial intelligence (AI) by yielding excellent performance in numerous tasks, from understanding images (He et al., 2016) to text (Conneau et al., 2016). Yet, high performance is not always sufficient – some real-world deployment scenarios necessitate that an ideal AI system be ‘interpretable’, such that it builds trust by explaining the rationale behind its decisions, allows detection of common failure cases and biases, and refrains from making decisions without sufficient confidence. In their conventional form, deep neural networks are considered black-box models – they are controlled by complex nonlinear interactions between many parameters that are difficult to understand. There are numerous approaches, e.g. (Kim et al., 2018a; Erhan et al., 2009; Zeiler and Fergus, 2013; Simonyan et al., 2013), that bring post-hoc explainability of decisions to already-trained models. Yet, these have the fundamental limitation that the models are not designed for interpretability. There are also approaches that redesign neural networks to be inherently interpretable, as in this paper. Notable ones include sequential attention (Bahdanau et al., 2015), capsule networks (Sabour et al., 2017), and interpretable convolutional filters (Zhang et al., 2018).

Figure 1: ProtoAttend bases the decision on a few prototypes from the database. This enables interpretability of the prediction (by visualizing the highest weight prototypes) and confidence estimation for the decision (by measuring agreement across prototype labels).

We focus on inherently-interpretable deep neural network modeling built on the foundations of prototypical learning. Prototypical learning decomposes decision making into known samples (see Fig. 1), referred to here as prototypes. We base our method on the principle that prototypes should constitute a minimal subset of samples with high interpretable value that can serve as a distillation or condensed view of a dataset (Bien and Tibshirani, 2012). Given that the number of objects a human can interpret is limited (Miller, 1956), outputting a few prototypes can be an effective approach for humans to understand the AI model behavior. In addition to such interpretability, prototypical learning: (1) provides an efficient confidence metric by measuring mismatches in prototype labels, allowing performance to be improved by refraining from making predictions in the absence of sufficient confidence, (2) helps detect deviations in the test distribution by measuring mismatches in prototype labels that represent the support of the training dataset, and (3) enables performance in the high label-noise regime to be improved by controlling the number of selected prototypes. Given these motivations, prototypes should be controllable in number and should be perceptually relevant to the input in explaining the decision-making task. Prototype selection in its naive form is computationally expensive and perceptually challenging (Bien and Tibshirani, 2012). We design ProtoAttend to address this problem in an efficient way. Our contributions can be summarized as follows:

  1. We propose a novel method, ProtoAttend, for selecting input-dependent prototypes based on an attention mechanism between the input and prototype candidates. ProtoAttend is model-agnostic and can even be integrated with pre-trained models.

  2. ProtoAttend allows interpreting the contribution of each prototype via the attention outputs.

  3. For a ‘condensed view’, we demonstrate that sparsity in weights can be efficiently imposed via the choice of the attention normalization and additional regularization.

  4. On image, text and tabular data, we demonstrate the four key benefits of ProtoAttend: interpretability, confidence control, diagnosis of distribution mismatch, and robustness against label noise. ProtoAttend yields superior quality for sample-based interpretability, better-calibrated confidence scoring, and more sensitive out-of-distribution detection compared to alternative approaches.

  5. ProtoAttend enables all these benefits via the same architecture and method, while maintaining comparable overall accuracy.

2 Related Work

Prototypical learning:  The principles of ProtoAttend are inspired by (Bien and Tibshirani, 2012), who formulate prototype selection as an integer program and solve it using a greedy approach with linear program relaxation. It is unclear whether such approaches can be efficiently adapted to deep learning. (Chen et al., 2018) and (Li et al., 2018) introduce a prototype layer for interpretability by replacing the conventional inner product with a distance computation for perceptual similarity. In contrast, our method uses an attention mechanism to quantify perceptual similarity and can choose input-dependent prototypes from a large-scale candidate database. (Yeh et al., 2018) decomposes the prediction into a linear combination of activations of training points for interpretability using representer values. The linear decomposition idea also exists in ProtoAttend, but the weights are learned via an attention mechanism and sparsity is encouraged in the decomposition. In (Koh and Liang, 2017), the training points that are most responsible for a given prediction are identified using influence functions via oracle access to gradients and Hessian-vector products.

Metric learning:  Metric learning aims to find an embedding representation of the data where similar data points are close and dissimilar data points are far from each other. ProtoAttend is motivated by efficient learning of such an embedding space, which can be used to decompose decisions. Metric learning for deep neural networks is typically based on modifications to the objective function, such as using triplet loss and N-pair loss (Sohn, 2016; Cui et al., 2016; Hoffer and Ailon, 2014). These yield perceptually meaningful embedding spaces, yet typically require a large subset of nearest neighbors to avoid degradation in performance (Cui et al., 2016). (Kim et al., 2018b) proposes a deep metric learning framework that employs an attention-based ensemble with a divergence loss so that each learner can attend to different parts of the object. Our method has metric learning capabilities like relating similar data points, but also performs well on the ultimate supervised learning task.

Attention-based few-shot learning:  Some of our inspirations are based on recent advances in attention-based few-shot learning. In (Vinyals et al., 2016), an attention mechanism is used to relate an example with candidate examples from a support set using a weighted nearest-neighbor classifier applied within an embedding space. In (Ren et al., 2018), incremental few-shot learning is implemented using an attention attractor network on the encoded and support sets. In (Snell et al., 2017), a non-linear mapping is learned to determine the prototype of a class as the mean of its support set in the embedding space. During training, the support set is randomly sampled to mimic the inference task. Overall, the attention mechanism in our method follows related principles but fundamentally differs in that few-shot learning aims for generalization to unseen classes whereas the goal of our method is robust and interpretable learning for seen classes.

Uncertainty and confidence estimation:  ProtoAttend takes a novel perspective on the perennial problem of quantifying how much deep neural networks’ predictions can be trusted. Common approaches are based on the scores of the prediction model, such as the probabilities from the softmax layer of a neural network, yet it has been shown that these raw confidence values are typically poorly calibrated (Guo et al., 2017). An ensemble of models (Lakshminarayanan et al., 2017) is one of the simplest and most efficient approaches, but it significantly increases complexity and decreases interpretability. In (Papernot and McDaniel, 2018), the intermediate representations of the network are used to define a distance metric, and a confidence metric is proposed based on the conformity of the neighbors. (Jiang et al., 2018) proposes a confidence metric based on the agreement between the classifier and a modified nearest-neighbor classifier on the test sample. In (DeVries and Taylor, 2018), direct inference of a confidence output is considered with a modified loss. Another direction for uncertainty and confidence estimation is Bayesian neural networks, which return a distribution over the outputs (Kendall and Gal, 2017; Mullachery et al., 2018).

3 ProtoAttend: Attention-based Prototypical Learning

Consider a training set with labels, {(x_i, y_i)}_{i=1}^N. Conventional supervised learning aims to learn a model ŷ_i = f(x_i; θ) that minimizes a predefined loss Σ_{i=1}^B L(ŷ_i, y_i) at each iteration, where B is the batch size for training (θ represents the trainable parameters of f and is sometimes not shown for notational convenience). Our goal is to impose that decision making is based on only a small number of training examples, i.e. prototypes, such that their linear superposition in an embedding space yields the overall decision and the superposition weights correspond to their importance. Towards this goal, we propose defining a solution to prototypical learning with the following six principles:

  (i) The encoder e(·) encodes all relevant information of x_i for the final decision. e(·) considers the global distribution of the samples, i.e. it learns from all (x_i, y_i). Although all the information in the training dataset is embodied in the weights of the encoder (training of e(·) may also involve initialization with pre-trained models or transfer learning), we construct the learning method in such a way that the decision is dominated by the prototypes with high weights.

  (ii) From the encoded information e(x_i), we can find a decision function g(·) so that the mapping ŷ_i = g(e(x_i)) is close to the ground truth y_i, consistent with conventional supervised learning.

  (iii) Given candidates {x̃_j}_{j=1}^{N_c} to select the prototypes from, there exist weights p_{i,j} (where p_{i,j} ≥ 0 and Σ_{j=1}^{N_c} p_{i,j} = 1) such that the decision g(Σ_{j=1}^{N_c} p_{i,j} e(x̃_j)) is close to the ground truth y_i.

  (iv) When the linear combination Σ_j p_{i,j} e(x̃_j) is considered, prototypes with higher weights p_{i,j} have a higher contribution to the decision ŷ_i.

  (v) The weights p_{i,j} should be sparse – only a controllable number of weights should be non-zero. Ideally, there exists an efficient mechanism for controlling the sparsity of p_{i,j} without significantly affecting performance.

  (vi) The weights p_{i,j} depend on the relation between the input x_i and the candidate samples x̃_j, based on their perceptual relation for the decision-making task. We do not introduce any heuristic relatedness metric such as distances in the representation space, but instead let the model learn the relation function that helps the overall performance.

Learning involves optimization of the parameters of the corresponding functions. If the proposed principles (such as reasoning from the linear combination of embeddings or assigning relevance to the weights) are not imposed during training but only at inference, high performance cannot be obtained due to the train-test mismatch, as the intermediate representations can be learned in an arbitrary way with no requirement to satisfy them. (For example, commonly-used distance metrics in representation spaces fail at determining perceptual relevance between samples when the model is trained in a vanilla way (Sitawarin and Wagner, 2019).) The subsequent section presents ProtoAttend and the training procedure that implements it.

3.1 Network architecture and training

The principles above are conditioned on efficient learning of an encoding function to encode the relevant information for decision making, a relation function to determine the prototype weights, and a final decision-making block to return the output. Conventional supervised learning comprises the encoding and decision blocks; the challenge is to design a learning method with a relation function of reasonable complexity. To this end, we adapt the idea of attention (Corbetta and Shulman, 2002; Vaswani et al., 2017), where the model focuses on an adaptive small portion of the input while making the decision. Different from the conventional use of attention in sequence or visual learning, we propose to use attention at the sample level, such that the attention mechanism determines the prototype weights by relating the input and the candidate samples via alignment of their keys and queries. Fig. 2 shows the proposed architecture for training and inference. The three main blocks are described below:

Figure 2: ProtoAttend method for training and testing. A shared encoder between the input samples and the candidate samples generates representations that are mapped to key, query and value embeddings (with a single nonlinear layer). The alignment between keys and queries determines the weights of the prototypes, and the linear combination of the values determines the final decision. Conformity of the prototype labels is used as a confidence metric.

Encoder:  A trainable encoder is employed to transform the B input samples (B may be 1 at inference) and the N_c samples from the database of prototype candidates (N_c may be as large as the entire training dataset at inference) into keys, queries and values. The encoder is shared and jointly updated for the input samples and the prototype candidate database, to learn a common representation space for the values. The encoder architecture can be based on any trainable discriminative feature mapping function, e.g. ResNet (He et al., 2016) for images, with the modification of generating three types of embeddings. For mapping the last encoder layer to the key, query and value embeddings, we simply use a single fully-connected layer with a nonlinearity, separately for each (there are other viable options for this mapping, but we restrict it to a single layer to minimize the number of additional trainable parameters, which becomes negligible in most cases). For input samples, v_i and q_i denote the values and queries, and for candidate database samples, k_j and ṽ_j denote the keys and values.
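For concreteness, a minimal sketch of these embedding heads is given below, written in PyTorch purely for illustration (the paper does not prescribe a framework; the module name, argument names and default sizes are assumptions, not a transcription of the authors' implementation).

```python
import torch
import torch.nn as nn

class ProtoAttendHeads(nn.Module):
    """Maps encoder representations to query/key/value embeddings.

    One fully-connected layer + ReLU per embedding type, with layer
    normalization on the values, as described in the text. Sizes are
    illustrative defaults.
    """
    def __init__(self, enc_dim: int, attn_dim: int = 16, value_dim: int = 64):
        super().__init__()
        self.query = nn.Sequential(nn.Linear(enc_dim, attn_dim), nn.ReLU())
        self.key = nn.Sequential(nn.Linear(enc_dim, attn_dim), nn.ReLU())
        self.value = nn.Sequential(nn.Linear(enc_dim, value_dim), nn.ReLU(),
                                   nn.LayerNorm(value_dim))

    def forward(self, h_input, h_candidates):
        # h_input: [B, enc_dim] encoded input samples (shared encoder)
        # h_candidates: [Nc, enc_dim] encoded prototype candidates (same encoder)
        q = self.query(h_input)             # queries of inputs,     [B, attn_dim]
        v = self.value(h_input)             # values of inputs,      [B, value_dim]
        k = self.key(h_candidates)          # keys of candidates,    [Nc, attn_dim]
        v_tilde = self.value(h_candidates)  # values of candidates,  [Nc, value_dim]
        return q, k, v, v_tilde
```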

Relational attention:  The relational attention yields the weight p_{i,j} between the i-th input sample and the j-th candidate via alignment of the corresponding key and query in dot-product attention form (we use q_i to denote the i-th row of the query matrix Q, and similarly k_j for the j-th row of the key matrix K):

p_{i,j} = [ f_norm( q_i K^T / √d ) ]_j,    (1)

where d is the attention dimension and f_norm is a normalization function satisfying p_{i,j} ≥ 0 and Σ_{j=1}^{N_c} p_{i,j} = 1, for which we consider softmax and sparsemax (Martins and Astudillo, 2016) (sparsemax encourages sparsity via the Euclidean projection onto the probabilistic simplex). The choice of the normalization function is an efficient mechanism to control the sparsity of the prototype weights, as demonstrated in the experiments. Note that the relational attention mechanism does not introduce any extra trainable parameters.
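A sketch of this step is shown below, including a sparsemax implementation following (Martins and Astudillo, 2016). The 1/√d scaling and the tensor shapes are assumptions in line with standard dot-product attention, not a verbatim transcription of the authors' code.

```python
import torch
import torch.nn.functional as F

def sparsemax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sparsemax (Martins & Astudillo, 2016): Euclidean projection of z onto the
    probability simplex; low-scoring entries receive exactly-zero weight."""
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    k = torch.arange(1, z.size(dim) + 1, device=z.device, dtype=z.dtype)
    shape = [1] * z.dim()
    shape[dim] = -1
    k = k.view(shape)
    z_cumsum = z_sorted.cumsum(dim)
    support = (1 + k * z_sorted) > z_cumsum            # coordinates that stay non-zero
    k_z = support.to(z.dtype).sum(dim=dim, keepdim=True)
    tau = (torch.where(support, z_sorted, torch.zeros_like(z_sorted)).sum(dim, keepdim=True) - 1) / k_z
    return torch.clamp(z - tau, min=0.0)

def relational_attention(q, k, normalization: str = "softmax"):
    """Prototype weights p[i, j] between input i and candidate j, as in Eq. (1).
    q: [B, d] input queries, k: [Nc, d] candidate keys."""
    d = q.size(-1)
    scores = q @ k.t() / d ** 0.5                      # [B, Nc] dot-product alignment
    if normalization == "softmax":
        return F.softmax(scores, dim=-1)
    return sparsemax(scores, dim=-1)
```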

Decision making:  The final decision block simply consists of a linear mapping g(·) from a convex combination of values to the output ŷ_i. Consider the convex combination of value embeddings, parameterized by α ∈ [0, 1]:

ŷ_i^(α) = g( (1 − α) v_i + α Σ_{j=1}^{N_c} p_{i,j} ṽ_j ).    (2)

For α = 0, L(ŷ_i^(0), y_i) is the conventional supervised learning loss (ignoring the relational attention mechanism) that can only impose principles (i) and (ii), but not principles (iii)-(vi). A high accuracy for ŷ^(0) merely indicates that the value embedding space represents each input sample accurately. For α = 1, L(ŷ_i^(1), y_i) encourages principles (i) and (iii)-(iv), but not principles (ii) and (vi) (for example, by simply assigning non-zero weights to samples from another predetermined class, a prototypical learning method can obtain perfect accuracy, but the assignment of the predetermined class would be arbitrary). A high accuracy for ŷ^(1) indicates that the linear combination of value embeddings accurately maps to the decision. For (vi), we propose that there should be a similar output mapping for the input and the prototypes, for which a loss term that is a mixture of L(ŷ^(0), y) and L(ŷ^(1), y), or guidance with an intermediate term such as L(ŷ^(0.5), y), is required to encourage high accuracy for both ŷ^(0) and ŷ^(1). Lastly, for α ≤ 0.5, the input sample itself has the largest contribution in the linear combination. Intuitively, the sample itself should be more relevant to the output than other samples, so principles (iii) and (iv) can be encouraged. We propose and compare different training objective functions in Table 1. We observe that the last four are all viable options as the training objective, with similar performance. We choose the last one for the rest of the experiments, as in some cases slightly better prototypes are observed qualitatively (see Sect. 5.2 for further discussion).

Training objective function      Acc. % for ŷ^(0)   Acc. % for ŷ^(1)   Avg. conf. (correct)   Avg. conf. (incorrect)
L^(0)                            94.28              13.13              0.029                  0.194
L^(1)                            10.92              94.21              0.103                  0.002
L^(t/T)                          94.01              94.25              0.927                  0.049
L^(0) + L^(1)                    94.37              94.38              0.931                  0.047
L^(0.5) + L^(1)                  94.14              94.18              0.927                  0.049
L^(0) + L^(0.5) + L^(1)          94.37              94.45              0.928                  0.047
Table 1: Ablation study. Impact of various training losses on ProtoAttend with softmax attention for Fashion-MNIST. L^(α) denotes the supervised loss on ŷ^(α); t is the training iteration index and T is the total number of iterations. The last two columns report the average confidence score (Sect. 3.2) on correctly vs. incorrectly classified test samples.
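The α-parameterized decision of Eq. (2) and a mixed training objective of the kind compared in Table 1 can be sketched as follows. The particular set of α values below is illustrative, and `g` stands for the final linear decision layer.

```python
import torch.nn as nn
import torch.nn.functional as F

def decision(g: nn.Linear, v, v_tilde, p, alpha: float):
    """y_hat^(alpha) = g((1 - alpha) * v_i + alpha * sum_j p_ij * v~_j), Eq. (2)."""
    mixed = (1.0 - alpha) * v + alpha * (p @ v_tilde)   # [B, value_dim]
    return g(mixed)                                     # class logits, [B, num_classes]

def mixed_objective(g, v, v_tilde, p, labels, alphas=(0.0, 0.5, 1.0)):
    """Sum of cross-entropy losses at several alpha values, so that both the
    input-only (alpha=0) and prototype-only (alpha=1) paths map to the label."""
    return sum(F.cross_entropy(decision(g, v, v_tilde, p, a), labels) for a in alphas)
```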

To control the sparsity of the weights (beyond the choice of the attention normalization), we also propose a sparsity regularization term, with coefficient λ_sparse, in the form of an entropy, Σ_i Σ_j −p_{i,j} log(p_{i,j} + ε), where ε is a small number for numerical stability. This term is minimized when, for each input, p_{i,j} has only one non-zero value.
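A minimal sketch of this entropy-based regularizer (the reduction over the batch and the value of ε are assumptions):

```python
import torch

def sparsity_regularizer(p: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Entropy of the prototype weights p: [B, Nc]; minimized when each row has a
    single non-zero entry. Added to the loss with coefficient lambda_sparse."""
    return (-p * torch.log(p + eps)).sum(dim=-1).mean()
```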

3.2 Confidence scoring using prototypes

Figure 3: Impact of confidence on ProtoAttend accuracy. Reliability diagram for Fashion-MNIST, as in (Papernot and McDaniel, 2018). Bars (left axis) indicate the mean accuracy of predictions binned by confidence; the red line (right axis) shows the number of samples across bins.

ProtoAttend provides a linear decomposition (via value embeddings) of the decision into prototypes that have known labels. Ideally, the labels of the prototypes should all be the same as the label of the input. When the prototypes with high weights belong to the same class, the model should be more confident and a correct classification result is expected, whereas when the prototype labels disagree, the model should be less confident and the likelihood of a wrong prediction is higher. With the motivation of separating correct vs. incorrect decisions via its value, we propose a confidence score based on the agreement between the prototypes:

c_i = Σ_{j=1}^{N_c} p_{i,j} 1[ ỹ_j = ŷ_i ],    (3)

where 1[·] is the indicator function and ỹ_j is the label of the j-th candidate. Table 1 shows the significant difference in the average confidence metric between correct vs. incorrect classification cases on the test dataset, as desired. In Fig. 3, the impact of confidence on accuracy is further analyzed with a reliability diagram as in (Papernot and McDaniel, 2018). When test samples are binned according to their confidence, the bins with higher confidence yield much higher accuracy. There is a small number of samples in the bins with lower confidence, and those tend to be the incorrect classification cases. In Section 4.4, the efficacy of the confidence score in separating correct vs. incorrect classification is evaluated in the confidence-controlled prediction setting, demonstrating how much prediction accuracy can be improved by refraining from predicting on the small number of low-confidence samples at test time.
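The confidence score of Eq. (3) can be computed directly from the prototype weights and the candidate labels; a short sketch (tensor shapes are assumptions):

```python
import torch

def confidence_score(p: torch.Tensor, candidate_labels: torch.Tensor,
                     predictions: torch.Tensor) -> torch.Tensor:
    """c_i = sum_j p_ij * 1[candidate label j == predicted class i], Eq. (3).

    p: [B, Nc] prototype weights, candidate_labels: [Nc], predictions: [B].
    Returns a confidence score in [0, 1] per input sample.
    """
    agree = (candidate_labels.unsqueeze(0) == predictions.unsqueeze(1)).to(p.dtype)  # [B, Nc]
    return (p * agree).sum(dim=-1)
```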

To further encourage confidence during training, we also consider a confidence regularization term, with coefficient λ_conf, penalizing low confidence scores c_i; it is minimized when all prototypes with p_{i,j} > 0 are from the same ground-truth class as the output ŷ_i. (Note that the gradient of this regularization term with respect to p_{i,j} is either 0 or 1, so it is often insufficient to train the model from scratch by itself, but it is observed to provide further improvements in some cases.)

4 Experiments

4.1 Setup

We demonstrate the results of ProtoAttend for image, text and tabular data classification problems with different encoder architectures (see Supplementary Material for details). Outputs of the encoders are mapped to queries, keys and values using a fully-connected layer followed by ReLU. For values, layer normalization (Lei Ba et al., 2016) is employed for more stable training. A fully-connected layer is used in the decision-making block, yielding logits for determining the estimated class. Softmax cross-entropy loss is used as the loss L. The Adam optimization algorithm is employed (Kingma and Ba, 2014) with exponential learning rate decay (with parameters optimized on a validation set). For image encoding, unless specified, we use the standard ResNet model (He et al., 2016). For text encoding, we use the very deep convolutional neural network (VDCNN) model (Conneau et al., 2016), which takes a sequence of raw characters as input. For tabular data encoding, we use an LSTM model (Hochreiter and Schmidhuber, 1997), which inputs the feature embeddings at every timestep. See Supplementary Material for implementation details, additional results and discussions.

4.2 Sparse explanations of decisions

Dataset          Method                              Acc. %   Prototypes (50%)   Prototypes (90%)   Prototypes (95%)
MNIST Baseline enc. 99.70 -
Softmax attn. 99.66 365 1324 1648
Sparsemax attn. 99.69 2 4 5
Fashion-MNIST Baseline enc. 94.74 -
Softmax attn. 94.42 712 2320 2702
Sparsemax attn. 94.42 4 10 11
Sparsemax attn. + sparsity reg. 94.47 1 2 2
CIFAR-10 Baseline enc. 91.97 -
Softmax attn. 91.69 317 1453 1898
Sparsemax attn. 91.44 5 14 16
Sparsemax attn. + sparsity reg. 91.26 2 3 4
DBPedia Baseline enc. 98.25 -
Softmax attn. 98.20 63 190 225
Sparsemax attn. 97.74 2 4 4
Income Baseline enc. 85.68 -
Softmax attn. 85.64 2263 9610 12419
Sparsemax attn. 85.58 20 57 67
Sparsemax attn. + sparsity reg. 85.41 3 6 7
Table 2: ProtoAttend achieves interpretability without significant degradation in performance. Accuracy and median number of prototypes to add up to 50%, 90% and 95% of the decision, quantified with prototype weights.

We foremost demonstrate that our inherently-interpretable model design does not cause significant degradation in performance. Table 2 shows the accuracy and the median number of prototypes required to add up to a particular portion of the decision (e.g. if the prototype weights are [0.2, 0.15, 0.15, 0.25, 0.1, 0.05, 0.28, 0.02], then 2 prototypes are required for 50% of the decision, 6 for 90% and 7 for 95%) for different prototypical learning cases. In all cases, a very small accuracy gap is observed with respect to the baseline encoder trained in the conventional supervised learning way. The attention normalization function and the sparsity regularization are efficient mechanisms to control the sparsity – the number of prototypes required is much lower with sparsemax attention compared to softmax attention and can be further reduced with sparsity regularization (see Supplementary Material for details). With a small decrease in performance, the number of prototypes can be reduced to just a handful. (We observe that excessively high sparsity, e.g. yielding 1-2 prototypes in most cases, may sometimes decrease the quality of prototypes due to overfitting to discriminative features that are less perceptually meaningful.) There are differences between datasets, as intuitively expected from the discrepancy in the degree of similarity between intra-class samples.
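The counting rule used for Table 2 (and illustrated in the parenthetical example above) can be sketched as follows; the function name and the use of NumPy are illustrative.

```python
import numpy as np

def prototypes_needed(weights: np.ndarray, fractions=(0.5, 0.9, 0.95)) -> dict:
    """Number of highest-weight prototypes whose cumulative weight reaches each
    target fraction of the total decision, for one input sample."""
    w = np.sort(weights)[::-1]          # descending prototype weights
    cum = np.cumsum(w) / w.sum()        # cumulative fraction of the decision
    return {f: int(np.searchsorted(cum, f) + 1) for f in fractions}

# Table 2 reports, per dataset and method, the median of these counts over test samples.
```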

(a) MNIST & Fashion MNIST
(b) Fruits
Figure 4: Example inputs and ProtoAttend prototypes for (a) MNIST (with sparsemax), Fashion-MNIST dataset (with sparsemax and sparsity regularization) and (b) Fruits (with sparsemax and sparsity regularization). For MNIST & Fashion-MNIST, prototypes typically consist of discriminative features such as the straight line shape for the digit 1, and the long heels and strips for the sandal. For Fruits, prototypes often correspond to the same fruit captured from a very similar angle.
Figure 5: Example inputs and ProtoAttend prototypes for DBPedia (with sparsemax). While classifying the inputs as athlete, prototypes have very similar sentence structure, words and concepts.
Figure 6: Example inputs and ProtoAttend prototypes for Adult Census Income (with sparsemax and sparsity regularization). For the first example, all prototypes have similar age, two share similar education level and one has the same occupation. For the second example, three prototypes have the same occupation, all work more than 40 hours/week, and three have postgraduate education.

Figs. 4, 5 and 6 exemplify prototypes for image, text and tabular data. In general, perceptually-similar samples are chosen as the prototypes with the largest weights. We also compare the relevant samples found by ProtoAttend with the methods of representer point selection (Yeh et al., 2018) and influence functions (Koh and Liang, 2017) (see Supplementary Material for details) on the Animals with Attributes dataset. As shown in Fig. 7, our method finds qualitatively more relevant samples. This case also exemplifies the potential of our method for integration into pre-trained models by adding simple layers for key, query and value generation.

Figure 7: Samples found by ProtoAttend vs. representer point selection (Yeh et al., 2018) and influence function (Koh and Liang, 2017) for the two examples from (Yeh et al., 2018) on Animals with Attributes dataset. See Supplementary Material for more examples.

4.3 Robustness to label noise

Noise level   Test acc. % (Baseline)   Test acc. % (Dropout)   Test acc. % (ProtoAttend)
0.8 57.02 56.76 60.50
0.6 71.27 72.15 74.67
0.4 77.47 78.99 80.04
Table 3: Label noise ratio vs. accuracy for baseline encoder, dropout method (Arpit et al., 2017) (optimizing the keep probability) and ProtoAttend with sparsemax attention and sparsity regularization for CIFAR-10.

As prototypical learning with sparsemax attention aims to extract the decision-making information from a small subset of training samples, it can be used to improve performance when the training dataset contains noisy labels (see Table 3). The optimal value of the sparsity regularization coefficient increases with higher noisy-label ratios, underlining the increasing importance of sparse learning. (For a fair comparison, we re-optimize the learning rate parameters on a separate validation set.)

4.4 Confidence-controlled prediction

Figure 8: Confidence-controlled prediction. (a) Accuracy vs. ratio of samples for MNIST. We compare dkNN (Papernot and McDaniel, 2018) and prototypical learning (with softmax attention and a regularization coefficient of 0.1) using the same network architecture from (Papernot and McDaniel, 2018) without augmentation. (b) Accuracy vs. ratio of samples for CIFAR-10. We compare prototypical learning (with softmax attention and a regularization coefficient of 0.1) with the trust score (Jiang et al., 2018) and deep ensemble (Lakshminarayanan et al., 2017) methods for the same baseline encoder network architecture.

By varying the threshold for the confidence metric, a trade-off is obtained between the ratio of test samples for which the model makes a prediction and the accuracy it obtains on the samples above that threshold. (This trade-off is often more meaningful to consider than metrics based on the actual value of the confidence score itself, as methods may differ in how they define the confidence metric and thus yield very different ranges and distributions for it.) Figs. 8(a) and 8(b) demonstrate this trade-off and compare it to alternative methods. The sharper slope of the curves shows that our method is superior to dkNN (Papernot and McDaniel, 2018) and trust score (Jiang et al., 2018), methods based on quantifying the mismatch with nearest-neighbor samples, in terms of finding related samples. Although the baseline accuracy is higher with the 4-network deep ensemble (Lakshminarayanan et al., 2017), our method utilizes a single network, and the additional accuracy gains from refraining from uncertain predictions are similar, as shown by the similar slopes of the curves.
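Such a coverage-accuracy trade-off curve can be computed by sweeping a threshold over the confidence scores; a minimal sketch (the threshold grid is an assumption):

```python
import numpy as np

def coverage_accuracy_curve(confidence: np.ndarray, correct: np.ndarray,
                            thresholds=np.linspace(0.0, 0.999, 100)):
    """For each confidence threshold, return (fraction of samples kept,
    accuracy on the kept samples). `correct` is a boolean array."""
    curve = []
    for t in thresholds:
        keep = confidence >= t
        if keep.any():
            curve.append((keep.mean(), correct[keep].mean()))
    return curve
```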

Overall, the baseline accuracy can be significantly improved by making fewer predictions. Compared to state-of-the-art models, our canonical method with simple and small models shows similar accuracy while making slightly fewer predictions – e.g. for MNIST, (Wan et al., 2013) achieves a 0.21% error rate, which our method obtains with ResNet-32 by refraining from only 0.45% of predictions, and for DBpedia, (Sachan and Petuum, 2018) achieves 0.91% error, which our method obtains with a 9-layer VDCNN by refraining from 3% of predictions. In general, the smaller the number of prototypes, the smaller the trade-off space. Thus, softmax attention (which normally yields more prototypes) is better suited for confidence-controlled prediction than sparsemax (see Supplementary Material for more comparisons).

4.5 Out-of-distribution samples

Well-calibrated confidence scores at inference can be used to detect deviations from the training dataset. As the test distribution deviates from the training distribution, prototype labels tend to mismatch more and yield lower confidence scores. Fig. 9 (a) shows the ratio of samples above a certain confidence level as the test dataset deviates. Rotations shift the distribution of test images away from the training images and cause significant degradation in confidence scores, as well as in the overall accuracy. Using test images from a different dataset degrades them even further. Next, Fig. 9 (b) quantifies out-of-distribution detection with prototypical learning, using the method from (Hendrycks and Gimpel, 2016). ProtoAttend yields an AUC of 0.838, on par with state-of-the-art approaches (Hendrycks et al., 2018).
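A sketch of how the detection AUC can be computed from the confidence scores of in-distribution and out-of-distribution test samples, in the spirit of (Hendrycks and Gimpel, 2016); the use of scikit-learn here is an illustrative choice:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def ood_auc(conf_in: np.ndarray, conf_out: np.ndarray) -> float:
    """AUC for separating in-distribution from out-of-distribution samples using
    the confidence score as the detection statistic (higher score = in-distribution)."""
    labels = np.concatenate([np.ones_like(conf_in), np.zeros_like(conf_out)])
    scores = np.concatenate([conf_in, conf_out])
    return roc_auc_score(labels, scores)
```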

Figure 9: Out-of-distribution detection. (a) Ratio of samples above the confidence level for prototypical learning with softmax attention, trained on Fashion-MNIST and tested on the shown datasets. E.g. if we consider the samples above confidence 0.9, it is far more likely that those samples come from the same distribution as the training dataset. (b) ROC curve for in-distribution vs. out-of-distribution detection, using CIFAR-10 as in-distribution and SVHN as out-of-distribution, computed using the method from (Hendrycks and Gimpel, 2016) and compared to the baseline proposed in (Hendrycks and Gimpel, 2016). Softmax attention and confidence regularization are used.

5 Computational Cost

ProtoAttend requires only a very small increase in the number of learnable parameters (merely two extra small matrices for the fully-connected layers that produce queries and keys). However, it does require a longer training time and has higher memory requirements to process the candidate database. At inference, keys and values for the candidate database can be computed only once and integrated into the model. Thus, the overhead merely becomes the computation of the attention outputs (e.g. for the CIFAR-10 model, the attention overhead at inference is less than 0.6 MFLOPs, orders of magnitude lower than the computational complexity of a ResNet model). During training, on the other hand, both forward and backward propagation steps of the encoder need to be computed for all candidate samples, and the total time is higher (e.g. ~4.45 times slower to train until convergence for CIFAR-10 compared to conventional supervised learning). The size of the candidate database is limited by the memory of the processor, so in practice we sample different candidate databases randomly from the training dataset at each iteration. For faster training, data and model parallelism approaches are straightforward to implement – e.g., different processors can focus on different samples, or they can focus on different parts of the convolution or inner product operations. Further computationally-efficient approaches may involve less frequent updates for the candidate keys and values.
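A sketch of the inference-time caching described above, assuming encoder and embedding-head modules like those sketched in Sect. 3.1 (names are illustrative):

```python
import torch

@torch.no_grad()
def build_candidate_cache(encoder, heads, candidate_loader):
    """Encode the prototype candidate database once; afterwards only the cheap
    attention computation and the final linear layer are needed per test query."""
    keys, values, labels = [], [], []
    for x, y in candidate_loader:
        h = encoder(x)
        keys.append(heads.key(h))      # cached candidate keys
        values.append(heads.value(h))  # cached candidate values
        labels.append(y)               # cached labels, for the confidence score
    return torch.cat(keys), torch.cat(values), torch.cat(labels)
```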

6 Conclusions

We propose an attention-based prototypical learning method, ProtoAttend, and demonstrate its usefulness for a wide range of problems on image, text and tabular data. By adding a relational attention mechanism to an encoder, prototypical learning enables novel capabilities. With sparsemax attention, it can base the learning on a few relevant samples that can be returned at inference for interpretability, and it also improves robustness to label noise. With softmax attention, it enables confidence-controlled prediction that can outperform state-of-the-art results with simple architectures by making slightly fewer predictions, as well as detection of deviations from the training data. All these capabilities are achieved without sacrificing the overall accuracy of the base model.

7 Acknowledgements

Discussions with Zizhao Zhang, Chih-Kuan Yeh, Nicolas Papernot, Ryan Takasugi, Andrei Kouznetsov, and Andrew Moore are gratefully acknowledged.

References

  • D. Arpit, S. Jastrzębski, N. Ballas, D. Krueger, E. Bengio, et al. (2017) A Closer Look at Memorization in Deep Networks. arXiv:1706.05394. Cited by: Table 3.
  • D. Bahdanau, K. Cho, and Y. Bengio (2015) Neural machine translation by jointly learning to align and translate. In ICLR, Cited by: §1.
  • J. Bien and R. Tibshirani (2012) Prototype selection for interpretable classification. arXiv:1202.5933. Cited by: §1, §2.
  • C. Chen, O. Li, A. Barnett, J. Su, and C. Rudin (2018) This looks like that: deep learning for interpretable image recognition. arXiv:1806.10574. Cited by: §2.
  • A. Conneau, H. Schwenk, L. Barrault, and Y. LeCun (2016) Very deep convolutional networks for natural language processing. arXiv:1606.01781. Cited by: §A.2.1, §1, §4.1.
  • M. Corbetta and G. L. Shulman (2002) Control of goal-directed and stimulus-driven attention in the brain. Nature Reviews Neuroscience 3, pp. 201–215. Cited by: §3.1.
  • Y. Cui, F. Zhou, Y. Lin, and S. J. Belongie (2016) Fine-grained categorization and dataset bootstrapping using deep metric learning with humans in the loop. CVPR. Cited by: §2.
  • T. DeVries and G. W. Taylor (2018) Learning Confidence for Out-of-Distribution Detection in Neural Networks. arXiv:1802.04865. Cited by: §2.
  • D. Erhan, Y. Bengio, A. Courville, and P. Vincent (2009) Visualizing higher-layer features of a deep network. In Technical report, Cited by: §1.
  • C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger (2017) On calibration of modern neural networks. arXiv:1706.04599. Cited by: §2.
  • H. A. Haenssle, C. Fink, R. Schneiderbauer, F. Toberer, T. Buhl, et al. (2018) Man against machine: diagnostic performance of a deep learning convolutional neural network for dermoscopic melanoma recognition in comparison to 58 dermatologists. Annals of Oncology 29 (8), pp. 1836–1842. Cited by: Appendix C.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In CVPR, Cited by: §1, §3.1, §4.1.
  • D. Hendrycks and K. Gimpel (2016) A baseline for detecting misclassified and out-of-distribution examples in neural networks. arXiv:1610.02136. Cited by: Figure 9, §4.5.
  • D. Hendrycks, M. Mazeika, and T. G. Dietterich (2018) Deep anomaly detection with outlier exposure. arXiv:1812.04606. Cited by: §4.5.
  • S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural Computation 9 (8), pp. 1735–1780. Cited by: §4.1.
  • E. Hoffer and N. Ailon (2014) Deep metric learning using triplet network. arXiv:1412.6622. Cited by: §2.
  • ISIC (2016) ISIC Archive. Cited by: §A.1.5, Appendix C.
  • A. F. Jerant, J. T. Johnson, C. D. Sheridan, and T. J. Caffrey (2000) Early detection and treatment of skin cancer. Am Fam Physician. Cited by: Appendix B.
  • H. Jiang, B. Kim, and M. R. Gupta (2018) To trust or not to trust a classifier. In NIPS, Cited by: §2, Figure 8, §4.4.
  • A. Kendall and Y. Gal (2017) What uncertainties do we need in Bayesian deep learning for computer vision?. In NIPS, Cited by: §2.
  • B. Kim, M. Wattenberg, J. Gilmer, C. Cai, J. Wexler, F. Viegas, and R. Sayres (2018a) Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV). In ICML, Cited by: §1.
  • W. Kim, B. Goyal, K. Chawla, J. Lee, and K. Kwon (2018b) Attention-based ensemble for deep metric learning. In ECCV, Cited by: §2.
  • D. P. Kingma and J. Ba (2014) Adam: A method for stochastic optimization. In ICLR, Cited by: §4.1.
  • P. W. Koh and P. Liang (2017) Understanding Black-box Predictions via Influence Functions. In ICML, Cited by: §2, Figure 7, §4.2.
  • B. Lakshminarayanan, A. Pritzel, and C. Blundell (2017) Simple and scalable predictive uncertainty estimation using deep ensembles. In NIPS, Cited by: §2, Figure 8, §4.4.
  • J. Lei Ba, J. R. Kiros, and G. E. Hinton (2016) Layer Normalization. arXiv:1607.06450. Cited by: §A.1.1, §A.1.3, §A.1.4, §A.1.5, §A.2.1, §4.1.
  • O. Li, H. Liu, C. Chen, and C. Rudin (2018) Deep learning for case-based reasoning through prototypes: A neural network that explains its predictions. In AAAI, Cited by: §2.
  • A. F. T. Martins and R. F. Astudillo (2016) From softmax to sparsemax: A sparse model of attention and multi-label classification. In ICML, Cited by: §3.1.
  • G. A. Miller (1956) The magical number seven, plus or minus two: some limits on our capacity for processing information. Psychological Review 63, pp. 81–97. Cited by: §1.
  • V. Mullachery, A. Khera, and A. Husain (2018) Bayesian neural networks. arXiv:1801.07710. Cited by: §2.
  • N. Papernot and P. D. McDaniel (2018) Deep k-nearest neighbors: towards confident, interpretable and robust deep learning. arXiv:1803.04765. Cited by: §2, Figure 3, §3.2, Figure 8, §4.4.
  • M. Ren, R. Liao, E. Fetaya, and R. S. Zemel (2018) Incremental few-shot learning with attention attractor networks. arXiv:1810.07218. Cited by: §2.
  • S. Sabour, N. Frosst, and G. E. Hinton (2017) Dynamic routing between capsules. In NIPS, Cited by: §1.
  • D. S. Sachan and Petuum (2018) Revisiting lstm networks for semi-supervised text classification via mixed objective function. In KDD, Cited by: §4.4.
  • K. Simonyan, A. Vedaldi, and A. Zisserman (2013) Deep inside convolutional networks: visualising image classification models and saliency maps. arXiv:1312.6034. Cited by: §1.
  • C. Sitawarin and D. A. Wagner (2019) On the robustness of deep k-nearest neighbors. arXiv:1903.08333. Cited by: footnote 3.
  • J. Snell, K. Swersky, and R. S. Zemel (2017) Prototypical networks for few-shot learning. In NIPS, Cited by: §2.
  • K. Sohn (2016) Improved deep metric learning with multi-class n-pair loss objective. In NIPS, Cited by: §2.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, et al. (2017) Attention is all you need. arXiv:1706.03762. Cited by: §3.1.
  • O. Vinyals, C. Blundell, T. P. Lillicrap, K. Kavukcuoglu, and D. Wierstra (2016) Matching networks for one shot learning. In NIPS, Cited by: §2.
  • L. Wan, M. Zeiler, S. Zhang, Y. L. Cun, and R. Fergus (2013) Regularization of neural networks using dropconnect. In ICML, Cited by: §4.4.
  • C. Yeh, J. S. Kim, I. E. Yen, and P. Ravikumar (2018) Representer point selection for explaining deep neural networks. arXiv:1811.09720. Cited by: §A.1.6, Figure 13, Appendix B, §2, Figure 7, §4.2.
  • M. D. Zeiler and R. Fergus (2013) Visualizing and understanding convolutional networks. arXiv:1311.2901. Cited by: §1.
  • Q. Zhang, Y. N. Wu, and S. C. Zhu (2018) Interpretable convolutional neural networks. In CVPR, Cited by: §1.

Appendix A Training details

Different candidate databases are sampled randomly from the training dataset at each iteration. The training database size N_c is chosen so that the model fits in the memory of a single GPU. At inference, N_c is chosen sufficiently large to obtain high accuracy. Table 4 shows the database sizes for the datasets used in the experiments.

Dataset          Encoder   Database size (training)   Database size (inference)
MNIST ResNet 1024 32768
Fashion-MNIST ResNet 1024 32768
CIFAR-10 ResNet 1024 32768
Fruits ResNet 256 4096
ISIC Melanoma ResNet 256 4096
DBPedia VDCNN 512 4096
Census Income LSTM 4096 15360
Table 4: Datasets and candidate database sizes for training and inference.
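The per-iteration candidate sampling described above can be sketched as follows (function and variable names are illustrative):

```python
import numpy as np

def sample_candidate_batch(train_images, train_labels, db_size: int, rng=np.random):
    """Draw a fresh prototype-candidate database from the training set for one
    training iteration (db_size as in Table 4, e.g. 1024 for CIFAR-10)."""
    idx = rng.choice(len(train_images), size=db_size, replace=False)
    return train_images[idx], train_labels[idx]
```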

a.1 Image data

a.1.1 MNIST dataset

We apply random cropping after padding each side by 2 pixels, and per-image standardization. The base encoder uses a standard 32-layer ResNet architecture. The number of filters is initially 16 and is doubled every 5 blocks. In each block, two convolutional layers are used to transform the input, and the transformed output is added to the input after a convolution. Downsampling is applied by choosing a stride of 2 at the stage transitions where the number of filters is doubled. Each convolution is followed by batch normalization and a ReLU nonlinearity. After the last convolution, average pooling is applied. The output is followed by a fully-connected layer of 256 units and a ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by a ReLU nonlinearity, where the attention size is 16. Values are mapped from the output using a fully-connected layer of 64 units and a ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is chosen as 0.002 and exponential decay is applied with a rate of 0.9 every 6k iterations. The model is trained for 84k iterations. For the prototypical learning model with softmax attention, the initial learning rate is chosen as 0.002 and exponential decay is applied with a rate of 0.8 every 8k iterations. The model is trained for 228k iterations. For the prototypical learning model with sparsemax attention, the initial learning rate is chosen as 0.001 and exponential decay is applied with a rate of 0.93 every 6k iterations. The model is trained for 228k iterations. All models use a batch size of 128 and gradient clipping above 20.

a.1.2 Fashion-MNIST dataset

We apply random cropping after padding each side by 2 pixels, random horizontal flipping, and per-image standardization. The base encoder uses a standard 32-layer ResNet architecture, similar to our MNIST experiments. For the baseline encoder, the initial learning rate is chosen as 0.0015 and exponential decay is applied with a rate of 0.9 every 10k iterations. The model is trained for 332k iterations. For prototypical learning with softmax attention, the initial learning rate is chosen as 0.0007 and exponential decay is applied with a rate of 0.92 every 8k iterations. The model is trained for 450k iterations. For prototypical learning with sparsemax attention, the initial learning rate is chosen as 0.001 and exponential decay is applied with a rate of 0.9 every 8k iterations. The model is trained for 392k iterations. For prototypical learning with sparsemax attention and sparsity regularization, the initial learning rate is chosen as 0.001 and exponential decay is applied with a rate of 0.94 every 8k iterations; the confidence regularization coefficient is chosen separately when confidence regularization is applied. The model is trained for 440k iterations. All models use a batch size of 128 and gradient clipping above 20.

a.1.3 CIFAR-10 dataset

We apply random cropping after padding each side by 3 pixels, random horizontal flipping, random vertical flipping and per-image standardization. The base encoder uses a standard 50-layer ResNet architecture. The number of filters is initially 16 and is doubled every 8 blocks. In each block, two convolutional layers are used to transform the input, and the transformed output is added to the input after a convolution. Downsampling is applied by choosing a stride of 2 at the stage transitions where the number of filters is doubled. Each convolution is followed by batch normalization and a ReLU nonlinearity. After the last convolution, average pooling is applied. The output is followed by a fully-connected layer of 256 units and a ReLU nonlinearity with layer normalization (Lei Ba et al., 2016), and then by a fully-connected layer of 512 units and a ReLU nonlinearity with layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by a ReLU nonlinearity, where the attention size is 16. Values are mapped from the output using a fully-connected layer of 128 units and a ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is chosen as 0.002 and exponential decay is applied with a rate of 0.95 every 10k iterations. The model is trained for 940k iterations. For prototypical learning with softmax attention, the initial learning rate is chosen as 0.0035 and exponential decay is applied with a rate of 0.95 every 10k iterations. The model is trained for 625k iterations. For prototypical learning with sparsemax attention, the initial learning rate is chosen as 0.0015 and exponential decay is applied with a rate of 0.95 every 10k iterations. The model is trained for 905k iterations. For prototypical learning with sparsemax attention and sparsity regularization, the initial learning rate is chosen as 0.0015 and exponential decay is applied with a rate of 0.95 every 12k iterations; the confidence regularization coefficient is chosen separately when confidence regularization is applied. The model is trained for 450k iterations. All models use a batch size of 128 and gradient clipping above 20.

CIFAR-10 experiments with noisy labels.

For the CIFAR-10 experiments with noisy labels, we only optimize the learning parameters of the base encoder. Noisy labels are sampled uniformly from the set of labels excluding the correct one. The baseline model with a noisy-label ratio of 0.8 uses an initial learning rate of 0.001, decayed with a rate of 0.92 every 6k iterations, and is trained for 15k iterations. For the dropout approach, dropout with a rate of 0.1 is applied, and the model uses an initial learning rate of 0.002, decayed with a rate of 0.85 every 8k iterations, and is trained for 24k iterations. The baseline model with a noisy-label ratio of 0.6 uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 6k iterations, and is trained for 12k iterations. For the dropout approach, dropout with a rate of 0.3 is applied, and the model uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 8k iterations, and is trained for 18k iterations. The baseline model with a noisy-label ratio of 0.4 uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 6k iterations, and is trained for 15k iterations. For the dropout approach, dropout with a rate of 0.5 is applied, and the model uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 6k iterations, and is trained for 18k iterations. For the experiments with the prototypical learning model with sparsemax attention, we optimize the learning parameters and the sparsity regularization coefficient. For the model with a noisy-label ratio of 0.8, the initial learning rate is chosen as 0.0006 and exponential decay is applied with a rate of 0.95 every 8k iterations. The model is trained for 108k iterations. For the model with a noisy-label ratio of 0.6, the initial learning rate is chosen as 0.001 and exponential decay is applied with a rate of 0.9 every 8k iterations. The model is trained for 92k iterations. For the model with a noisy-label ratio of 0.4, the initial learning rate is chosen as 0.001 and exponential decay is applied with a rate of 0.9 every 6k iterations. The model is trained for 122k iterations.

a.1.4 Fruits dataset

We apply random cropping after padding each side by 5 pixels, random horizontal flipping, random vertical flipping and per-image standardization. In the encoder, downsampling is first applied with a convolutional layer with a stride of 2 and 16 filters, followed by downsampling with max-pooling with a stride of 2. After obtaining the downsampled inputs, a standard 32-layer ResNet architecture (similar to MNIST) is used, followed by a fully-connected layer of 128 units and a ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by a ReLU nonlinearity, where the attention size is 16. Values are mapped from the output using a fully-connected layer of 64 units and a ReLU nonlinearity, followed by layer normalization. Weight decay with a factor of 0.0001 is applied to the convolutional filters. The model uses a batch size of 128 and gradient clipping above 20.

a.1.5 ISIC Melanoma dataset

The ISIC Melanoma dataset is formed from the ISIC Archive (ISIC, 2016), which contains over 13k dermoscopic images collected from leading clinical centers internationally and acquired from a variety of devices within each center. The dataset consists of skin images with labels denoting whether they contain melanoma or are benign. We construct the training and validation dataset using 15122 images (13511 benign and 1611 melanoma cases), and the evaluation dataset using 3203 images (2867 benign and 336 melanoma). During training, benign cases are undersampled in each batch to a ratio of 0.6, including the candidate database sets at training and inference. All images are resized to a fixed resolution. We apply random cropping after padding each side by 8 pixels, random horizontal flipping, random vertical flipping and per-image standardization. In the encoder, downsampling is first applied with a convolutional layer with a stride of 2 and 16 filters, followed by downsampling with max-pooling with a stride of 2. After obtaining the downsampled inputs, the base encoder uses a standard 50-layer ResNet architecture (similar to CIFAR-10), followed by a fully-connected layer of 128 units and a ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by a ReLU nonlinearity, where the attention size is 16. Values are mapped from the output using a fully-connected layer of 64 units and a ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is chosen as 0.002 and exponential decay is applied with a rate of 0.9 every 3k iterations. The model is trained for 220k iterations. For prototypical learning with softmax attention, the initial learning rate is chosen as 0.0006 and exponential decay is applied with a rate of 0.9 every 3k iterations. The model is trained for 147k iterations. For prototypical learning with sparsemax attention, the initial learning rate is chosen as 0.0006 and exponential decay is applied with a rate of 0.9 every 4k iterations. The model is trained for 166k iterations. All models use a batch size of 128 and gradient clipping above 20.

a.1.6 Animals with Attributes dataset

We train ProtoAttend with sparsemax attention using the features from a pre-trained ResNet-50, as provided in (Yeh et al., 2018). To map the pre-trained features, we simply insert a single fully-connected layer with 256 units with a ReLU nonlinearity and layer normalization, followed by the individual fully-connected layers for keys, queries and values (16, 16 and 64 units respectively, with ReLU nonlinearity). Sparsity regularization is applied. We train the model for 70k iterations. The initial learning rate is chosen as 0.0006 and exponential decay is applied with a rate of 0.8 every 10k iterations. A classification accuracy above 91% is obtained on the test set.

a.2 Text data

a.2.1 DBPedia dataset

There are 14 output classes: Company, Educational Institution, Artist, Athlete, Office Holder, Mean Of Transportation, Building, Natural Place, Village, Animal, Plant, Album, Film, Written Work. As the input, 16-dimensional trainable embeddings are mapped from the dictionary of 69 raw characters (Conneau et al., 2016). The maximum length is set to 448; longer inputs are truncated and shorter inputs are padded. The input embeddings are first transformed with a 1-D convolutional block consisting of 64 filters with a kernel width of 3 and a stride of 2. Then, 8 convolution blocks as in (Conneau et al., 2016) are applied, with 64, 64, 128, 128, 256, 256, 512 and 512 filters respectively. All use a kernel width of 3, and after every two layers, max pooling is applied with a kernel width of 3 and a stride of 2. All convolutions are followed by batch normalization and a ReLU nonlinearity. Convolutional filters use weight normalization with parameter 0.00001. The last convolution block is followed by k-max pooling with k = 8 (Conneau et al., 2016). Finally, we apply two fully-connected layers with 1024 hidden units. In contrast to (Conneau et al., 2016), we also use layer normalization (Lei Ba et al., 2016) after the fully-connected layers, as we observe this leads to more stable training behavior. Keys and queries are mapped from the output using a fully-connected layer followed by a ReLU nonlinearity, where the attention size is 16. Values are mapped from the output using a fully-connected layer of 64 units and a ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is chosen as 0.0008 and exponential decay is applied with a rate of 0.9 every 8k iterations. The model is trained for 212k iterations. For the prototypical learning model with softmax attention, the initial learning rate is chosen as 0.0008 and exponential decay is applied with a rate of 0.9 every 8k iterations. The model is trained for 146k iterations. For the prototypical learning model with sparsemax attention, the initial learning rate is chosen as 0.0005 and exponential decay is applied with a rate of 0.82 every 8k iterations. The model is trained for 270k iterations. All models use a batch size of 128 and gradient clipping above 20. We do not apply any data augmentation.

a.3 Tabular data

a.3.1 Adult Census Income

There are two output classes: whether or not the annual income is above $50k. Categorical features such as ‘marital-status’ are mapped to multi-hot representations. Continuous variables are used after a fixed normalization transformation. For ‘age’, the transformation subtracts 50 and then divides by 30. For ‘fnlwgt’, the transformation takes the log, subtracts 9, and then divides by 3. For ‘education-num’, the transformation subtracts 6 and then divides by 6. For ‘hours-per-week’, the transformation subtracts 50 and then divides by 50. For ‘capital-gain’ and ‘capital-loss’, the normalization takes the log, subtracts 5, and then divides by 5. The concatenated features are then mapped to a 64-dimensional vector using a fully-connected layer, followed by a ReLU nonlinearity. The base encoder uses an LSTM architecture with 4 timesteps. At each timestep, the 64-dimensional inputs are applied after dropout with a rate of 0.5. The output of the last timestep is used after applying dropout with a rate of 0.5. Keys and queries are mapped from this output using a fully-connected layer followed by a ReLU nonlinearity, where the attention size is 16. Values are mapped from the output using a fully-connected layer of 16 units and a ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is chosen as 0.002 and exponential decay is applied with a rate of 0.9 every 2k iterations. The model is trained for 4.5k iterations. For the models with attention in the prototypical learning framework, the initial learning rate is chosen as 0.0005 and exponential decay is applied with a rate of 0.92 every 2k iterations. The softmax attention model is trained for 13.5k iterations and the sparsemax attention model for 11.5k iterations. For the model with sparsity regularization, the initial learning rate is 0.003 and exponential decay is applied with a rate of 0.7 every 2k iterations, and the model is trained for 7k iterations. All models use a batch size of 128 and gradient clipping above 20. We do not apply any data augmentation.
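For reference, the fixed normalization transformations described above can be written out as a short sketch (column names follow the standard Adult Census schema; handling of zero-valued monetary features is noted as an assumption):

```python
import numpy as np

def normalize_continuous(row: dict) -> dict:
    """Fixed transformations for the continuous Adult Census features."""
    return {
        "age":            (row["age"] - 50.0) / 30.0,
        "fnlwgt":         (np.log(row["fnlwgt"]) - 9.0) / 3.0,
        "education-num":  (row["education-num"] - 6.0) / 6.0,
        "hours-per-week": (row["hours-per-week"] - 50.0) / 50.0,
        # the text specifies a plain log; zero-valued gains/losses would need
        # e.g. log1p handling in practice (an assumption, not stated in the text)
        "capital-gain":   (np.log(row["capital-gain"]) - 5.0) / 5.0,
        "capital-loss":   (np.log(row["capital-loss"]) - 5.0) / 5.0,
    }
```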

Appendix B Additional prototype examples

(a) With sparsemax
(b) With sparsemax and sparsity regularization
Figure 10: Example inputs and corresponding prototypes for CIFAR-10.

Fig. 10 exemplifies prototypes for CIFAR-10. In most cases, we observe similarity of discriminative features between inputs and prototypes. For example, the body shapes of birds, the shape of tires, the face patterns of dogs, the body shapes of frogs, and the appearance of the background sky for planes are among the features apparent in the examples.

Figure 11: Example inputs and corresponding prototypes for DBPedia (with sparsemax).

Fig. 11 shows additional prototype examples for the DBPedia dataset. The prototypes share very similar sentence structure, words, and concepts with the inputs while categorizing the sentences into ontologies.

Fig. 12 shows example prototypes for ISIC Melanoma. In some cases, we observe commonalities between the input and the prototypes that distinguish melanoma cases, such as non-circular geometry or irregularly-notched borders (Jerant et al., 2000). Compared to the other datasets, the ISIC Melanoma dataset yields prototypes of lower interpretable quality on average. We hypothesize this is due to the perceptual difficulty of the problem as well as insufficient encoder performance, as indicated by the lower classification accuracy (despite the acceptable AUC).

Figure 12: Example inputs and corresponding prototypes for ISIC Melanoma (with sparsemax attention).

Fig. 13 shows more comparison examples for the prototypical learning framework with sparsemax attention vs. representer point selection (Yeh et al., 2018) on the Animals with Attributes dataset. For some cases, including chimpanzee, zebra, dalmatian and tiger, ProtoAttend yields perceptually very similar samples. The similarity of the chimpanzee body form and background, the zebra patterns, the dalmatian pattern on the grass, and the tiger pattern and head pose are prominent. Representer point selection fails to capture such similarity features as effectively. On the other hand, for bat, otter and wolf, the results are somewhat less satisfying. The wing of the bat, the multiple otters with the background, and the color and furry head of the wolf seem to be captured, but with less apparent similarity than some other possible samples from the dataset. Representer point selection also cannot be claimed to be successful in these cases. Lastly, for leopard, ProtoAttend yields only one non-zero prototype (which is indeed statistically rare given the model and sparsity choices). The pattern of the leopard image seems relevant, but it is not fully satisfying to observe a single prototype that is not perceptually more similar. All of the test examples in Fig. 13 are classified correctly by our framework, and all of the shown prototypes are from the correct classes.

Figure 13: Relevant samples found by ProtoAttend with sparsemax attention vs. representer point selection (Yeh et al., 2018) for the examples from Supplementary Material of (Yeh et al., 2018).

Appendix C Comparison of confidence-controlled prediction for softmax vs. sparsemax

(a) MNIST
(b) Fashion MNIST
Figure 14: Accuracy vs. ratio of samples for (a) MNIST and (b) Fashion MNIST, for confidence levels between 0 and 0.999.
(a) DBpedia
(b) Adult Census Income
Figure 15: Accuracy vs. ratio of samples for (a) DBpedia and (b) Adult Census Income, for confidence levels between 0 and 0.999.

Figs. 14 and 15 show the accuracy vs. ratio of samples for softmax vs. sparsemax attention without confidence regularization. The baseline accuracy (at 100% prediction ratio) is higher for softmax attention on some datasets and higher for sparsemax on others. On the other hand, the higher number of prototypes yielded by softmax attention results in a wider range for the confidence-controlled prediction trade-off.
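For clarity, one plausible way to compute these curves is sketched below: the confidence of a prediction is taken as the total prototype weight whose labels agree with the predicted class, and predictions are retained only when this confidence exceeds a threshold. Variable and function names are illustrative and not taken from the paper's code.

```python
import numpy as np


def confidence_curve(pred, labels, proto_weights, proto_labels, thresholds):
    """Accuracy vs. ratio of retained samples under confidence-controlled prediction.

    pred, labels:   (N,) predicted and true classes for the evaluation set.
    proto_weights:  (N, M) attention weights over the M candidate database samples (rows sum to 1).
    proto_labels:   (M,) labels of the candidate database samples.
    thresholds:     iterable of confidence levels, e.g. np.linspace(0.0, 0.999, 50).
    """
    agree = proto_labels[None, :] == pred[:, None]        # (N, M) label agreement mask
    confidence = (proto_weights * agree).sum(axis=1)      # prototype weight mass agreeing with prediction
    curve = []
    for t in thresholds:
        keep = confidence >= t                            # samples we do not refrain from
        ratio = keep.mean()
        acc = (pred[keep] == labels[keep]).mean() if keep.any() else float("nan")
        curve.append((t, ratio, acc))
    return curve
```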

As an impactful case study, we consider the melanoma detection problem with the ISIC dataset (ISIC, 2016). In medical diagnosis, it is strongly desired to maintain sufficiently-high prediction performance, potentially by having medical experts verify the decisions of an AI system in cases where the model is not confident. By refraining from some predictions, as shown in Fig. 16, we demonstrate unprecedentedly high AUC values without using transfer learning or highly-customized models (Haenssle et al., 2018).

Figure 16: Area-under-curve (AUC) vs. ratio of samples for ISIC Melanoma with softmax attention, for confidence values ranging between 0 and 0.99.

Appendix D Controlling sparsity via regularization

Figure 17: Number of training iterations vs. median number of prototypes to explain 95% of the decision (in logarithmic scale), for Fashion-MNIST with softmax attention.

Fig. 17 shows the impact of the sparsity regularization coefficient on training. By varying its value, the number of prototypes can be efficiently controlled. For high values of the sparsity regularization coefficient, the model gets stuck at a point where it is forced to make decisions from a low number of prototypes before the encoder is properly learned, and hence typically yields considerably lower performance. We also observe that the sparsity mechanism via sparsemax attention yields better performance than softmax attention with high sparsity regularization.
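The sparsity metric plotted in Fig. 17 can be computed as sketched below: for each sample, the attention weights are sorted in decreasing order and counted until they cumulatively cover 95% of the decision, and the median count is taken over the evaluation set. This is a sketch based on the figure's description; names are illustrative.

```python
import numpy as np


def median_prototypes_to_explain(weights, coverage=0.95):
    """weights: (N, M) attention weights over the candidate database, with rows summing to 1."""
    sorted_w = np.sort(weights, axis=1)[:, ::-1]      # descending weights per sample
    cum = np.cumsum(sorted_w, axis=1)
    counts = (cum < coverage).sum(axis=1) + 1         # prototypes needed to reach the coverage level
    return np.median(counts)
```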

Appendix E Prototype quality

In general, the following scenarios may yield low prototype quality:

  1. Lack of related samples in the candidate database.

  2. Perceptual difference between humans and encoders in determining discriminative features.

  3. High intra-class variability that makes training difficult.

  4. Imperfect encoder that cannot yield fully accurate representations of the input.

  5. Insufficiency of relational attention to determine weights from queries and keys.

  6. Inefficient decoupling between encoder & attention blocks and the final decision block.

There can be problem-dependent fundamental limitations for (1)-(3), whereas (4)-(6) arise from choices of models and losses and can be further improved. We leave the quantification of prototype quality using information-theoretic metrics or discriminative neural networks to future work.

Appendix F Understanding misclassification cases

One of the benefits of prototypical learning is the insight it provides into wrong decisions. Fig. 18 shows example prototypes with wrong labels, which give insight into why the model is confused by a particular input (e.g. due to the similarity of visual patterns). Such insights can be actionable for improving model performance, such as adding more training samples for the confusing classes or modifying the loss functions.

Figure 18: Example prototypes with wrong labels for CIFAR-10.