ProtoAttend: Attention-Based Prototypical Learning
Abstract
We propose a novel inherently interpretable machine learning method that bases decisions on few relevant examples that we call prototypes. Our method, ProtoAttend, can be integrated into a wide range of neural network architectures, including pretrained models. It utilizes an attention mechanism that relates the encoded representations to samples in order to determine prototypes. The resulting model outperforms the state of the art in three high-impact problems without sacrificing the accuracy of the original model: (1) it enables high-quality interpretability that outputs samples most relevant to the decision making (i.e., a sample-based interpretability method); (2) it achieves state-of-the-art confidence estimation by quantifying the mismatch across prototype labels; and (3) it obtains state-of-the-art performance in distribution mismatch detection. All this can be achieved with minimal additional test time and a practically viable training-time computational cost.
1 Introduction
Deep neural networks have been pushing the frontiers of artificial intelligence (AI) by yielding excellent performance in numerous tasks, from understanding images (He et al., 2016) to text (Conneau et al., 2016). Yet, high performance is not always a sufficient factor, as some real-world deployment scenarios might necessitate that an ideal AI system is ‘interpretable’, such that it builds trust by explaining the rationales behind its decisions, allows detection of common failure cases and biases, and refrains from making decisions without sufficient confidence. In their conventional form, deep neural networks are considered black-box models – they are controlled by complex nonlinear interactions between many parameters that are difficult to understand. There are numerous approaches, e.g. (Kim et al., 2018a; Erhan et al., 2009; Zeiler and Fergus, 2013; Simonyan et al., 2013), that bring post-hoc explainability of decisions to already-trained models. Yet, these have the fundamental limitation that the models are not designed for interpretability. There are also approaches that redesign neural networks to make them inherently interpretable, as in this paper. Some notable ones include sequential attention (Bahdanau et al., 2015), capsule networks (Sabour et al., 2017), and interpretable convolutional filters (Zhang et al., 2018).
We focus on inherently interpretable deep neural network modeling with the foundations of prototypical learning. Prototypical learning decomposes decision making into known samples (see Fig. 1), referred to here as prototypes. We base our method on the principle that prototypes should constitute a minimal subset of samples with high interpretable value that can serve as a distillation or condensed view of a dataset (Bien and Tibshirani, 2012). Given that the number of objects a human can interpret is limited (Miller, 1956), outputting few prototypes can be an effective approach for humans to understand the AI model behavior. In addition to such interpretability, prototypical learning: (1) provides an efficient confidence metric by measuring mismatches in prototype labels, allowing performance to be improved by refraining from making predictions in the absence of sufficient confidence; (2) helps detect deviations in the test distribution by measuring mismatches in prototype labels that represent the support of the training dataset; and (3) enables performance in the high-label-noise regime to be improved by controlling the number of selected prototypes. Given these motivations, prototypes should be controllable in number, and should be perceptually relevant to the input in explaining the decision-making task. Prototype selection in its naive form is computationally expensive and perceptually challenging (Bien and Tibshirani, 2012). We design ProtoAttend to address this problem in an efficient way. Our contributions can be summarized as follows:


We propose a novel method, ProtoAttend, for selecting input-dependent prototypes based on an attention mechanism between the input and prototype candidates. ProtoAttend is model-agnostic and can even be integrated with pretrained models.

ProtoAttend allows interpreting the contribution of each prototype via the attention outputs.

For a ‘condensed view’, we demonstrate that sparsity in weights can be efficiently imposed via the choice of the attention normalization and additional regularization.

On image, text and tabular data, we demonstrate the four key benefits of ProtoAttend: interpretability, confidence control, diagnosis of distribution mismatch, and robustness against label noise. ProtoAttend yields superior quality for sample-based interpretability, better-calibrated confidence scoring, and more sensitive out-of-distribution detection compared to alternative approaches.

ProtoAttend enables all these benefits via the same architecture and method, while maintaining comparable overall accuracy.
2 Related Work
Prototypical learning: The principles of ProtoAttend are inspired by (Bien and Tibshirani, 2012). They formulate prototype selection as an integer program and solve it using a greedy approach with linear program relaxation. It is unclear whether such approaches can be efficiently adapted to deep learning. (Chen et al., 2018) and (Li et al., 2018) introduce a prototype layer for interpretability by replacing the conventional inner product with a distance computation for perceptual similarity. In contrast, our method uses an attention mechanism to quantify perceptual similarity and can choose input-dependent prototypes from a large-scale candidate database. (Yeh et al., 2018) decomposes the prediction into a linear combination of activations of training points for interpretability using representer values. The linear decomposition idea also exists in ProtoAttend, but the weights are learned via an attention mechanism and sparsity is encouraged in the decomposition. In (Koh and Liang, 2017), the training points that are the most responsible for a given prediction are identified using influence functions via oracle access to gradients and Hessian-vector products.
Metric learning: Metric learning aims to find an embedding representation of the data where similar data points are close and dissimilar data points are far from each other. ProtoAttend is motivated by efficient learning of such an embedding space which can be used to decompose decisions. Metric learning for deep neural networks is typically based on modifications to the objective function, such as using triplet loss or N-pair loss (Sohn, 2016; Cui et al., 2016; Hoffer and Ailon, 2014). These yield perceptually meaningful embedding spaces yet typically require a large subset of nearest neighbors to avoid degradation in performance (Cui et al., 2016). (Kim et al., 2018b) proposes a deep metric learning framework which employs an attention-based ensemble with a divergence loss so that each learner can attend to different parts of the object. Our method has metric learning capabilities, like relating similar data points, but also performs well on the ultimate supervised learning task.
Attention-based few-shot learning: Some of our inspirations are based on recent advances in attention-based few-shot learning. In (Vinyals et al., 2016), an attention mechanism is used to relate an example with candidate examples from a support set using a weighted nearest-neighbor classifier applied within an embedding space. In (Ren et al., 2018), incremental few-shot learning is implemented using an attention attractor network on the encoded and support sets. In (Snell et al., 2017), a nonlinear mapping is learned to determine the prototype of a class as the mean of its support set in the embedding space. During training, the support set is randomly sampled to mimic the inference task. Overall, the attention mechanism in our method follows related principles but fundamentally differs in that few-shot learning aims for generalization to unseen classes whereas the goal of our method is robust and interpretable learning for seen classes.
Uncertainty and confidence estimation: ProtoAttend takes a novel perspective on the perennial problem of quantifying how much deep neural networks’ predictions can be trusted. Common approaches are based on using the scores from the prediction model, such as the probabilities from the softmax layer of a neural network, yet it has been shown that such raw confidence values are typically poorly calibrated (Guo et al., 2017). Ensembles of models (Lakshminarayanan et al., 2017) are one of the simplest and most efficient approaches, but they significantly increase complexity and decrease interpretability. In (Papernot and McDaniel, 2018), the intermediate representations of the network are used to define a distance metric, and a confidence metric is proposed based on the conformity of the neighbors. (Jiang et al., 2018) proposes a confidence metric based on the agreement between the classifier and a modified nearest-neighbor classifier on the test sample. In (DeVries and Taylor, 2018), direct inference of a confidence output is considered with a modified loss. Another direction for uncertainty and confidence estimation is Bayesian neural networks that return a distribution over the outputs (Kendall and Gal, 2017; Mullachery et al., 2018).
3 ProtoAttend: Attention-Based Prototypical Learning
Consider a training set with labels, $\{(x_i, y_i)\}_{i=1}^{N}$. Conventional supervised learning aims to learn a model $f_\theta$ that minimizes a predefined loss $\frac{1}{B}\sum_{b=1}^{B} \mathcal{L}(y_b, f_\theta(x_b))$^{1} at each iteration, where $B$ is the batch size for training. Our goal is to impose that decision making should be based on only a small number of training examples, i.e. prototypes, such that their linear superposition in an embedding space can yield the overall decision and the superposition weights correspond to their importance. Towards this goal, we propose defining a solution to prototypical learning with the following six principles: (^{1}$\theta$ represents the trainable parameters of $f_\theta$ and is sometimes not shown for notational convenience.)


(i) An encoder $e(\cdot)$ encodes all relevant information of the input $x$ for the final decision. $e(\cdot)$ considers the global distribution of the samples, i.e. it learns from all $x_i$. Although all the information in the training dataset is embodied in the weights of the encoder^{2}, we construct the learning method in such a way that the decision is dominated by the prototypes with high weights. (^{2}Training of $e(\cdot)$ may also involve initializing with pretrained models or transfer learning.)

(ii) From the encoded information, we can find a decision function $g(\cdot)$ so that the mapping $g(e(x))$ is close to the ground truth $y$, in a way consistent with conventional supervised learning.

(iii) Given candidates $\{\tilde{x}_j\}_{j=1}^{D}$ to select the prototypes from, there exist weights $p_j$ (where $p_j \geq 0$ and $\sum_{j=1}^{D} p_j = 1$) such that the decision $g(\hat{v})$ (where $\hat{v} = \sum_{j=1}^{D} p_j\, e(\tilde{x}_j)$) is close to the ground truth $y$.

(iv) When the linear combination is considered, prototypes with higher weights $p_j$ have a higher contribution to the decision $g(\hat{v})$.

(v) The weights $p_j$ should be sparse – only a controllable number of weights should be nonzero. Ideally, there exists an efficient mechanism for controlling the sparsity of $p_j$ without significantly affecting performance.

(vi) The weights $p_j$ depend on the relation between the input $x$ and the candidate samples $\tilde{x}_j$, based on their perceptual relation for decision making. We do not introduce any heuristic relatedness metric, such as distances in the representation space, but instead allow the model to learn the relation function that helps the overall performance.
Learning involves optimization of the parameters of the corresponding functions. If the proposed principles (such as reasoning from the linear combination of embeddings or assigning relevance to the weights) are not imposed during training but only at inference, high performance cannot be obtained due to the train-test mismatch, as the intermediate representations can be learned in an arbitrary way without any necessity to satisfy them.^{3} The subsequent section presents ProtoAttend and the training procedure to implement it. (^{3}For example, commonly-used distance metrics in representation spaces fail at determining the perceptual relevance between samples when the model is trained in a vanilla way (Sitawarin and Wagner, 2019).)
3.1 Network architecture and training
The principles above are conditioned on efficient learning of an encoding function to encode the relevant information for decision making, a relation function to determine the prototype weights, and a final decision-making block to return the output. Conventional supervised learning comprises the encoding and decision blocks. On the other hand, it is challenging to design a learning method with a relation function of reasonable complexity. To this end, we adapt the idea of attention (Corbetta and Shulman, 2002; Vaswani et al., 2017), where the model focuses on an adaptive small portion of the input while making a decision. Different from the conventional employment of attention in sequence or visual learning, we propose to use attention at the sample level, such that the attention mechanism determines the prototype weights by relating the input and the candidate samples via alignment of their keys and queries. Fig. 2 shows the proposed architecture for training and inference. The three main blocks are described below:
Encoder: A trainable encoder $e(\cdot)$ is employed to transform input samples (note that the batch size $B$ may be 1 at inference) and samples from the database of prototype candidates (note that the database size $D$ may be as large as the entire training dataset at inference) into keys, queries and values. The encoder is shared and jointly updated for the input samples and the prototype candidate database, to learn a common representation space for the values. The encoder architecture can be based on any trainable discriminative feature mapping function, e.g. ResNet (He et al., 2016) for images, with the modification of generating three types of embeddings. For the mapping of the last encoder layer to key, query and value embeddings, we simply use a single fully-connected layer with a nonlinearity, separately for each.^{4} For input samples, $q$ and $v$ denote the queries and values, and for candidate database samples, $k$ and $\tilde{v}$ denote the keys and values. (^{4}There are other viable options for the mapping but we restrict it to a single layer to minimize the additional number of trainable parameters, which becomes negligible in most cases.)
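As a concrete sketch of this mapping, the snippet below produces query, key and value embeddings from encoder features via one fully-connected layer plus ReLU per embedding type. The weights are randomly initialized placeholders standing in for learned parameters, and all function and dimension names are ours, not from a released implementation:

```python
import numpy as np

def embedding_heads(features, d_att, d_val, seed=0):
    """Map encoder outputs of shape (n, d_enc) to query, key and value
    embeddings, using one fully-connected layer + ReLU per embedding type.
    Weights are random placeholders standing in for learned parameters."""
    rng = np.random.default_rng(seed)
    d_enc = features.shape[1]
    out = []
    for d_out in (d_att, d_att, d_val):  # query, key, value dimensions
        W = rng.normal(scale=d_enc ** -0.5, size=(d_enc, d_out))
        out.append(np.maximum(features @ W, 0.0))  # ReLU nonlinearity
    return out  # [queries, keys, values]
```

In training, the same heads would be applied to both the input batch and the candidate database, since the encoder is shared.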
Relational attention: The relational attention yields the weight $p_{i,j}$ between the $i$-th input sample and the $j$-th candidate via alignment of the corresponding key and query in dot-product attention form^{5}:

$p_{i,j} = f\left( q_i k_j^{\top} / \sqrt{d} \right),$   (1)

(^{5}We use $q_i$ to denote the $i$-th row of the query matrix $Q$ and $k_j$ the $j$-th row of the key matrix $K$; $d$ is their common dimensionality.)
where $f(\cdot)$ is a normalization function satisfying $p_{i,j} \geq 0$ and $\sum_{j=1}^{D} p_{i,j} = 1$, for which we consider softmax and sparsemax (Martins and Astudillo, 2016)^{6}. The choice of the normalization function is an efficient mechanism to control the sparsity of the prototype weights, as demonstrated in the experiments. Note that the relational attention mechanism does not introduce any extra trainable parameters. (^{6}Sparsemax encourages sparsity via the Euclidean projection of the logits onto the probabilistic simplex.)
Decision making: The final decision block simply consists of a linear mapping from a convex combination of values that results in the output $\hat{y}$. Consider the convex combination of the value embeddings, parameterized by $r \in [0, 1]$:

$\hat{v}_i(r) = r\, v_i + (1 - r) \sum_{j=1}^{D} p_{i,j}\, \tilde{v}_j.$   (2)
For $r = 1$, $\mathcal{L}(y_i, g(\hat{v}_i(1)))$ is the conventional supervised learning loss (ignoring the relational attention mechanism) that can only impose principles (i) and (ii), but not principles (iii)-(vi). A high accuracy for $r = 1$ merely indicates that the value embedding space represents each input sample accurately. For $r = 0$, $\mathcal{L}(y_i, g(\hat{v}_i(0)))$ encourages principles (i) and (iii)-(iv), but not principles (ii) and (vi).^{7} A high accuracy for $r = 0$ indicates that the linear combination of value embeddings accurately maps to the decision. For (vi), we propose that there should be a similar output mapping for the input and the prototypes, for which we encourage high accuracy for both $r = 0$ and $r = 1$ with a loss term that is a mixture of $\mathcal{L}(y_i, g(\hat{v}_i(0)))$ and $\mathcal{L}(y_i, g(\hat{v}_i(1)))$, or guidance with an intermediate term, such as $\mathcal{L}(y_i, g(\hat{v}_i(0.5)))$. Lastly, when $r = 0.5$, we obtain the condition that the input sample itself has the largest contribution in the linear combination. Intuitively, the sample itself should be more relevant to the output than other samples, so principles (iii) and (iv) can be encouraged. We propose and compare different training objective functions in Table 1. We observe that the last four are all viable options as the training objective, with similar performance. We choose the last one for the rest of the experiments, as in some cases slightly better prototypes are observed qualitatively (see Sect. 5.2 for further discussion). (^{7}For example, by simply assigning nonzero weights to samples of another predetermined class, a prototypical learning method could obtain perfect accuracy, but the assignment of the predetermined class would be arbitrary.)
Training objective function | Acc. % for $r{=}0$ | Acc. % for $r{=}1$ | Avg. conf. (correct) | Avg. conf. (incorrect)
$\mathcal{L}(y, g(\hat{v}(0)))$ | 94.28 | 13.13 | 0.029 | 0.194
$\mathcal{L}(y, g(\hat{v}(1)))$ | 10.92 | 94.21 | 0.103 | 0.002
$\mathcal{L}(y, g(\hat{v}(0.5)))$ | 94.01 | 94.25 | 0.927 | 0.049
$\mathcal{L}(y, g(\hat{v}(0))) + \mathcal{L}(y, g(\hat{v}(1)))$ | 94.37 | 94.38 | 0.931 | 0.047
$\mathcal{L}(y, g(\hat{v}(0.5))) + \mathcal{L}(y, g(\hat{v}(1)))$ | 94.14 | 94.18 | 0.927 | 0.049
$\mathcal{L}(y, g(\hat{v}(0))) + \mathcal{L}(y, g(\hat{v}(0.5))) + \mathcal{L}(y, g(\hat{v}(1)))$ | 94.37 | 94.45 | 0.928 | 0.047
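Assuming $g$ is a linear layer trained with softmax cross-entropy (consistent with Sec. 4.1), the mixture objective chosen above (the last row of Table 1) can be sketched in NumPy as follows; all function and variable names here are ours, not from a released implementation:

```python
import numpy as np

def cross_entropy(logits, label):
    """Softmax cross-entropy via a numerically stable log-sum-exp."""
    m = logits.max()
    log_z = m + np.log(np.sum(np.exp(logits - m)))
    return log_z - logits[label]

def mixed_objective(v_input, values_cand, p, W_out, label, rs=(0.0, 0.5, 1.0)):
    """Sum of L(y, g(v_hat(r))) over mixing coefficients r, where
    v_hat(r) = r * v_input + (1 - r) * sum_j p_j v_j, as in Eq. (2)."""
    proto_value = p @ values_cand  # convex combination of candidate values
    total = 0.0
    for r in rs:
        v_hat = r * v_input + (1 - r) * proto_value
        total += cross_entropy(v_hat @ W_out, label)
    return total
```

Training with `rs=(1.0,)` alone recovers conventional supervised learning, and `rs=(0.0,)` alone recovers pure prototype-based decisions, matching the first two rows of Table 1.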
To control the sparsity of the weights (beyond the choice of the attention normalization), we also propose a sparsity regularization term with a coefficient $\lambda_{\text{sparsity}}$, in the form of the entropy $\mathcal{L}_{\text{sparsity}} = -\sum_{j=1}^{D} p_{i,j} \log(p_{i,j} + \epsilon)$, where $\epsilon$ is a small number for numerical stability. $\mathcal{L}_{\text{sparsity}}$ is minimized when $p_{i,\cdot}$ has only one nonzero value.
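As a sketch, the entropy regularizer is a one-liner (the function name and the choice of epsilon are ours):

```python
import numpy as np

def sparsity_regularizer(p, eps=1e-8):
    """Entropy of the prototype weight vector p (nonnegative, sums to 1).
    Approximately 0 when exactly one weight is nonzero; maximal for uniform p."""
    return float(-np.sum(p * np.log(p + eps)))
```

Minimizing this term alongside the supervised loss pushes the attention weights toward a one-hot distribution, i.e. toward fewer prototypes.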
3.2 Confidence scoring using prototypes
ProtoAttend provides a linear decomposition (via value embeddings) of the decision into prototypes that have known labels. Ideally, the labels of the prototypes should all be the same as the label of the input. When the prototypes with high weights belong to the same class, the model shall be more confident and a correct classification result is expected, whereas in cases of disagreement between prototype labels, the model shall be less confident and the likelihood of a wrong prediction is higher. With the motivation of separating correct vs. incorrect decisions via its value, we propose a confidence score based on the agreement between the prototypes:
$c_i = \sum_{j=1}^{D} p_{i,j}\, \mathbb{1}[\tilde{y}_j = \hat{y}_i],$   (3)
where $\mathbb{1}[\cdot]$ is the indicator function, $\tilde{y}_j$ is the label of the $j$-th candidate, and $\hat{y}_i$ is the predicted class for the $i$-th input. Table 1 shows the significant difference in the average confidence metric between correct vs. incorrect classification cases on the test dataset, as desired. In Fig. 3, the impact of confidence on accuracy is further analyzed with a reliability diagram, as in (Papernot and McDaniel, 2018). When test samples are binned according to their confidence, it is observed that the bins with higher confidence yield much higher accuracy. There are a small number of samples in the bins with lower confidence, and those tend to be the incorrect classification cases. In Section 4.4, the efficacy of the confidence score in separating correct vs. incorrect classifications is evaluated in the confidence-controlled prediction setting, demonstrating how much the prediction accuracy can be improved by refraining from predicting on a small number of low-confidence samples at test time.
To further encourage confidence during training, we also consider a regularization term $\mathcal{L}_{\text{conf}} = \sum_{j=1}^{D} p_{i,j}\, \mathbb{1}[\tilde{y}_j \neq \hat{y}_i]$ with a coefficient $\lambda_{\text{conf}}$. $\mathcal{L}_{\text{conf}}$ is minimized when all prototypes with $p_{i,j} > 0$ are from the same ground-truth class as the output $\hat{y}_i$.^{8} (^{8}Note that the gradient of this regularization term with respect to $p_{i,j}$ is either 0 or 1, and it is often insufficient to train the model from scratch by itself. But it is observed to provide further improvements in some cases.)
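Both the confidence score of Eq. (3) and the agreement regularization reduce to weighted indicator sums; a minimal sketch, under our assumption that candidate labels are available as an integer array:

```python
import numpy as np

def confidence_score(p, proto_labels, predicted_label):
    """Eq. (3): total attention weight on prototypes whose (known) label
    matches the model's prediction for the input."""
    return float(np.sum(p * (np.asarray(proto_labels) == predicted_label)))

def confidence_regularizer(p, proto_labels, predicted_label):
    """Penalizes weight placed on disagreeing prototypes; 0 when every
    prototype with nonzero weight shares the predicted class.
    Equals 1 - confidence_score since the weights sum to 1."""
    return 1.0 - confidence_score(p, proto_labels, predicted_label)
```

A score near 1 means the prototypes unanimously support the prediction; a low score flags a likely error or an out-of-distribution input.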
4 Experiments
4.1 Setup
We demonstrate the results of ProtoAttend for image, text and tabular data classification problems with different encoder architectures (see Supplementary Material for details). Outputs of the encoders are mapped to queries, keys and values using a fully-connected layer followed by ReLU. For values, layer normalization (Lei Ba et al., 2016) is employed for more stable training. A fully-connected layer is used in the decision-making block, yielding logits for determining the estimated class. Softmax cross-entropy loss is used as the loss function. The Adam optimization algorithm is employed (Kingma and Ba, 2014) with exponential learning rate decay (with parameters optimized on a validation set). For image encoding, unless specified, we use the standard ResNet model (He et al., 2016). For text encoding, we use the very deep convolutional neural network (VDCNN) (Conneau et al., 2016) model, inputting a sequence of raw characters. For tabular data encoding, we use an LSTM model (Hochreiter and Schmidhuber, 1997), which inputs the feature embeddings at every timestep. See Supplementary Material for implementation details, additional results and discussions.
4.2 Sparse explanations of decisions
Dataset | Method | Acc. % | No. of prototypes (50% / 90% / 95%)
MNIST | Baseline enc. | 99.70 | – / – / –
MNIST | Softmax attn. | 99.66 | 365 / 1324 / 1648
MNIST | Sparsemax attn. | 99.69 | 2 / 4 / 5
Fashion-MNIST | Baseline enc. | 94.74 | – / – / –
Fashion-MNIST | Softmax attn. | 94.42 | 712 / 2320 / 2702
Fashion-MNIST | Sparsemax attn. | 94.42 | 4 / 10 / 11
Fashion-MNIST | Sparsemax attn. + sparsity reg. | 94.47 | 1 / 2 / 2
CIFAR-10 | Baseline enc. | 91.97 | – / – / –
CIFAR-10 | Softmax attn. | 91.69 | 317 / 1453 / 1898
CIFAR-10 | Sparsemax attn. | 91.44 | 5 / 14 / 16
CIFAR-10 | Sparsemax attn. + sparsity reg. | 91.26 | 2 / 3 / 4
DBPedia | Baseline enc. | 98.25 | – / – / –
DBPedia | Softmax attn. | 98.20 | 63 / 190 / 225
DBPedia | Sparsemax attn. | 97.74 | 2 / 4 / 4
Income | Baseline enc. | 85.68 | – / – / –
Income | Softmax attn. | 85.64 | 2263 / 9610 / 12419
Income | Sparsemax attn. | 85.58 | 20 / 57 / 67
Income | Sparsemax attn. + sparsity reg. | 85.41 | 3 / 6 / 7
We foremost demonstrate that our inherently interpretable model design does not cause significant degradation in performance. Table 2 shows the accuracy and the median number of prototypes required to add up to a particular portion of the decision^{9} for different prototypical learning cases. In all cases, a very small accuracy gap is observed with respect to the baseline encoder trained in the conventional supervised learning way. The attention normalization function and sparsity regularization are efficient mechanisms to control the sparsity – the number of prototypes required is much lower with sparsemax attention compared to softmax attention, and can be further reduced with sparsity regularization (see Supplementary Material for details). With a small decrease in performance, the number of prototypes can be reduced to just a handful.^{10} There are differences across datasets, as intuitively expected from the discrepancy in the degree of similarity between intra-class samples. (^{9}E.g. if the prototype weights are [0.3, 0.15, 0.15, 0.2, 0.08, 0.05, 0.04, 0.03], then 2 prototypes are required for 50% of the decision, 6 for 90% and 7 for 95%.) (^{10}We observe that excessively high sparsity (e.g. to yield 1-2 prototypes in most cases) may sometimes decrease the quality of prototypes due to overfitting to discriminative features that are less perceptually meaningful.)
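The prototype counts reported in Table 2 correspond to the smallest set of largest-weight prototypes whose weights cover a given fraction of the decision. A minimal sketch with hypothetical weights (the function name is ours):

```python
import numpy as np

def prototypes_needed(p, fraction):
    """Number of largest-weight prototypes whose weights sum to `fraction`
    of the decision, given prototype weights p that sum to 1."""
    cum = np.cumsum(np.sort(np.asarray(p))[::-1])  # descending cumulative sum
    return int(np.searchsorted(cum, fraction) + 1)
```

Applying this per test sample and taking the median over the test set yields the 50%/90%/95% columns of Table 2.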
Figs. 4, 5 and 6 exemplify prototypes for image, text and tabular data. In general, perceptuallysimilar samples are chosen as the prototypes with the largest weights. We also compare the relevant samples found by ProtoAttend with the methods of representer point selection (Yeh et al., 2018) and influence functions (Koh and Liang, 2017) (see Supplementary Material for details) on Animals with Attributes dataset. As shown in Fig. 7, our method finds qualitatively more relevant samples. This case also exemplifies the potential of our method for integration into pretrained models by addition of simple layers for key, query and value generation.
4.3 Robustness to label noise
Noise level | Baseline acc. % | Dropout acc. % | ProtoAttend acc. %
0.8 | 57.02 | 56.76 | 60.50
0.6 | 71.27 | 72.15 | 74.67
0.4 | 77.47 | 78.99 | 80.04
As prototypical learning with sparsemax attention aims to extract decision-making information from a small subset of training samples, it can be used to improve performance when the training dataset contains noisy labels (see Table 3). The optimal value^{11} of $\lambda_{\text{sparsity}}$ increases with higher noisy-label ratios, underlining the increasing importance of sparse learning. (^{11}For a fair comparison, we re-optimize the learning rate parameters on a separate validation set.)
4.4 Confidencecontrolled prediction
By varying the threshold for the confidence metric, a tradeoff can be obtained between the ratio of test samples for which the model makes a prediction and the overall accuracy it obtains on the samples above that threshold.^{12} Figs. 8(a) and 8(b) demonstrate this tradeoff and compare it to alternative methods. The sharper slope of the plots shows that our method is superior to DkNN (Papernot and McDaniel, 2018) and trust score (Jiang et al., 2018), the methods based on quantifying the mismatch with nearest-neighbor samples, in terms of finding related samples. Although the baseline accuracy is higher with an ensemble of 4 networks obtained via deep ensembles (Lakshminarayanan et al., 2017), our method utilizes a single network, and the additional accuracy gain from refraining from uncertain predictions is similar to ours, as shown by the similar slopes of the curves. (^{12}Note that this tradeoff is often more meaningful to consider than metrics based on the actual value of the confidence score itself, as methods may differ in how they define the confidence metric, and thus yield very different ranges and distributions for it.)
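This tradeoff curve can be traced by sorting test samples by confidence and measuring accuracy on the retained fraction; a sketch with synthetic values (names are ours):

```python
import numpy as np

def accuracy_at_coverage(conf, correct, coverage):
    """Accuracy on the `coverage` fraction of samples with highest confidence,
    i.e. predictions are withheld for the least confident remainder."""
    conf = np.asarray(conf)
    correct = np.asarray(correct, dtype=float)
    n_keep = max(1, int(round(coverage * conf.size)))
    kept = np.argsort(-conf)[:n_keep]  # indices of the most confident samples
    return float(correct[kept].mean())
```

Sweeping `coverage` from 1.0 down toward 0 produces a curve like those in Figs. 8(a) and 8(b); a well-calibrated confidence score yields a steep accuracy gain as coverage decreases.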
Overall, the baseline accuracy can be significantly improved by making fewer predictions. Compared to state-of-the-art models, our canonical method with simple and small models achieves similar accuracy by making slightly fewer predictions – e.g. for MNIST, (Wan et al., 2013) achieves a 0.21% error rate, which our method matches while refraining from only 0.45% of predictions using ResNet-32, and for DBpedia, (Sachan and Petuum, 2018) achieves 0.91% error, which our method matches while refraining from 3% of predictions using a 9-layer VDCNN. In general, the smaller the number of prototypes, the smaller the tradeoff space. Thus, softmax attention (which normally results in more prototypes) is better suited for confidence-controlled prediction than sparsemax (see Supplementary Material for more comparisons).
4.5 Outofdistribution samples
Well-calibrated confidence scores at inference can be used to detect deviations from the training dataset. As the test distribution deviates from the training distribution, prototype weights tend to mismatch more and yield lower confidence scores. Fig. 9(a) shows the ratio of samples above a certain confidence level as the test dataset deviates. Rotations shift the distribution of test images away from the training images and cause significant degradation in confidence scores, as well as in the overall accuracy. Using test images from a different dataset degrades them even further. Next, Fig. 9(b) shows quantification of out-of-distribution detection with prototypical learning, using the method from (Hendrycks and Gimpel, 2016). ProtoAttend yields an AUC of 0.838, on par with state-of-the-art approaches (Hendrycks et al.).
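The AUROC-based evaluation of (Hendrycks and Gimpel, 2016) reduces to a ranking statistic over confidence scores. A minimal sketch in pairwise form, suitable only for small sample counts (the function name is ours):

```python
import numpy as np

def ood_auroc(conf_in, conf_out):
    """AUROC of the confidence score as an in- vs out-of-distribution
    detector: the probability that a random in-distribution sample outscores
    a random OOD sample, with ties counting half."""
    diff = np.asarray(conf_in)[:, None] - np.asarray(conf_out)[None, :]
    return float(np.mean(diff > 0) + 0.5 * np.mean(diff == 0))
```

An AUROC of 1.0 means the confidence score perfectly separates the two distributions; 0.5 corresponds to chance.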
5 Computational Cost
ProtoAttend requires only a very small increase in the number of learnable parameters (merely two extra small matrices for the fully-connected layers that produce queries and keys). However, it does require a longer training time and has higher memory requirements to process the candidate database. At inference, keys and values for the candidate database can be computed once and integrated into the model. Thus, the overhead merely becomes the computation of the attention outputs (e.g. for the CIFAR-10 model, the attention overhead at inference is less than 0.6 MFLOPs, orders of magnitude lower than the computational complexity of a ResNet model). During training, on the other hand, both forward and backward propagation steps for the encoder need to be computed for all candidate samples, and the total time is higher (e.g. 4.45 times slower to train until convergence on CIFAR-10 compared to conventional supervised learning). The size of the candidate database is limited by the memory of the processor, so in practice we sample different candidate databases randomly from the training dataset at each iteration. For faster training, data and model parallelism approaches are straightforward to implement – e.g., different processors can focus on different samples, or on different parts of the convolution or inner-product operations. Further computationally efficient approaches may involve less frequent updates of the candidate keys and values.
6 Conclusions
We propose an attention-based prototypical learning method, ProtoAttend, and demonstrate its usefulness for a wide range of problems on image, text and tabular data. By adding a relational attention mechanism to an encoder, prototypical learning enables novel capabilities. With sparsemax attention, it can base learning on a few relevant samples that can be returned at inference for interpretability, and it also improves robustness to label noise. With softmax attention, it enables confidence-controlled prediction that can outperform state-of-the-art results with simple architectures by simply making slightly fewer predictions, as well as detection of deviations from the training data. All these capabilities are achieved without sacrificing the overall accuracy of the base model.
7 Acknowledgements
Discussions with Zizhao Zhang, ChihKuan Yeh, Nicolas Papernot, Ryan Takasugi, Andrei Kouznetsov, and Andrew Moore are gratefully acknowledged.
References
 A Closer Look at Memorization in Deep Networks. arXiv:1706.05394. Cited by: Table 3.
 Neural machine translation by jointly learning to align and translate. In ICLR, Cited by: §1.
 Prototype selection for interpretable classification. arXiv:1202.5933. Cited by: §1, §2.
 This looks like that: deep learning for interpretable image recognition. arXiv:1806.10574. Cited by: §2.
 Very deep convolutional networks for natural language processing. arXiv:1606.01781. Cited by: §A.2.1, §1, §4.1.
 Control of goaldirected and stimulusdriven attention in the brain. Nature Reviews Neuroscience 3, pp. 201–215. Cited by: §3.1.
 Finegrained categorization and dataset bootstrapping using deep metric learning with humans in the loop. CVPR. Cited by: §2.
 Learning Confidence for OutofDistribution Detection in Neural Networks. arXiv:1802.04865. Cited by: §2.
 Visualizing higherlayer features of a deep network. In Technical report, Cited by: §1.
 On calibration of modern neural networks. arXiv:1706.04599. Cited by: §2.
 Man against machine: diagnostic performance of a deep learning convolutional neural network for dermoscopic melanoma recognition in comparison to 58 dermatologists. Annals of Oncology 29 (8), pp. 1836–1842. Cited by: Appendix C.
 Deep residual learning for image recognition. In CVPR, Cited by: §1, §3.1, §4.1.
 A baseline for detecting misclassified and outofdistribution examples in neural networks. arXiv:1610.02136. Cited by: Figure 9, §4.5.
 Deep anomaly detection with outlier exposure. arXiv:1812.04606. Cited by: §4.5.
 Long shortterm memory. Neural Computation 9 (8), pp. 1735–1780. Cited by: §4.1.
 Deep metric learning using triplet network. arXiv:1412.6622. Cited by: §2.
 ISIC Archive. External Links: Link Cited by: §A.1.5, Appendix C.
 Early detection and treatment of skin cancer. Am Fam Physician. Cited by: Appendix B.
 To trust or not to trust a classifier. In NIPS, Cited by: §2, Figure 8, §4.4.
 What uncertainties do we need in bayesian deep learning for computer vision?. In NIPS, Cited by: §2.
 Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV). In ICML, Cited by: §1.
 Attentionbased ensemble for deep metric learning. In ECCV, Cited by: §2.
 Adam: A method for stochastic optimization. In ICLR, Cited by: §4.1.
 Understanding Blackbox Predictions via Influence Functions. In ICML, Cited by: §2, Figure 7, §4.2.
 Simple and scalable predictive uncertainty estimation using deep ensembles. In NIPS, Cited by: §2, Figure 8, §4.4.
 Layer Normalization. arXiv:1607.06450. Cited by: §A.1.1, §A.1.3, §A.1.4, §A.1.5, §A.2.1, §4.1.
 Deep learning for casebased reasoning through prototypes: A neural network that explains its predictions. In AAAI, Cited by: §2.
 From softmax to sparsemax: A sparse model of attention and multilabel classification. In MLR, Cited by: §3.1.
 The magical number seven, plus or minus 2: some limits on our capacity for processing information. Psychological review 63, pp. 81–97. Cited by: §1.
 Bayesian neural networks. arXiv:1801.07710. Cited by: §2.
 Deep knearest neighbors: towards confident, interpretable and robust deep learning. arXiv:1803.04765. Cited by: §2, Figure 3, §3.2, Figure 8, §4.4.
 Incremental fewshot learning with attention attractor networks. arXiv:1810.07218. Cited by: §2.
 Dynamic routing between capsules. In NIPS, Cited by: §1.
 Revisiting lstm networks for semisupervised text classification via mixed objective function. In KDD, Cited by: §4.4.
 Deep inside convolutional networks: visualising image classification models and saliency maps. arXiv:1312.6034. Cited by: §1.
 On the robustness of deep knearest neighbors. arXiv:1903.08333. Cited by: footnote 3.
 Prototypical networks for fewshot learning. In NIPS, Cited by: §2.
 Improved deep metric learning with multiclass npair loss objective. In NIPS, Cited by: §2.
 Attention is all you need. arXiv:1706.03762. Cited by: §3.1.
 Matching networks for one shot learning. In NIPS, Cited by: §2.
 Regularization of neural networks using dropconnect. In ICML, Cited by: §4.4.
 Representer point selection for explaining deep neural networks. arXiv:1811.09720. Cited by: §A.1.6, Figure 13, Appendix B, §2, Figure 7, §4.2.
 Visualizing and understanding convolutional networks. arXiv:1311.2901. Cited by: §1.
 Interpretable convolutional neural networks. In CVPR, Cited by: §1.
Appendix A Training details
Different candidate databases are sampled randomly from the training dataset at each iteration. The training database size is chosen so that the model fits in the memory of a single GPU. The database size at inference is chosen sufficiently large to obtain high accuracy. Table 4 shows the database sizes for the datasets used in the experiments.
| Dataset | Encoder | Database size (training) | Database size (inference) |
| --- | --- | --- | --- |
| MNIST | ResNet | 1024 | 32768 |
| Fashion-MNIST | ResNet | 1024 | 32768 |
| CIFAR-10 | ResNet | 1024 | 32768 |
| Fruits | ResNet | 256 | 4096 |
| ISIC Melanoma | ResNet | 256 | 4096 |
| DBPedia | VDCNN | 512 | 4096 |
| Census Income | LSTM | 4096 | 15360 |
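The per-iteration database sampling described above can be sketched as follows. This is a minimal numpy sketch, not the paper's implementation; the function name `sample_candidate_database` is hypothetical:

```python
import numpy as np

def sample_candidate_database(train_x, train_y, db_size, rng):
    """Randomly draw a candidate database from the training set.

    A fresh database is sampled at every training iteration; db_size is
    limited by single-GPU memory during training and set larger at
    inference (see Table 4).
    """
    idx = rng.choice(len(train_x), size=db_size, replace=False)
    return train_x[idx], train_y[idx]

rng = np.random.default_rng(0)
train_x = np.zeros((500, 8, 8, 1))
train_y = np.zeros(500, dtype=int)
db_x, db_y = sample_candidate_database(train_x, train_y, 64, rng)
```

Sampling without replacement keeps the candidate database free of duplicate samples within an iteration.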
A.1 Image data
A.1.1 MNIST dataset
We apply random cropping after padding each side by 2 pixels, and per-image standardization. The base encoder uses a standard 32-layer ResNet architecture. The number of filters is initially 16 and is doubled every 5 blocks. In each block, two convolutional layers transform the input, and the transformed output is added to the input after a convolution. Downsampling is applied by choosing a stride of 2 after certain blocks. Each convolution is followed by batch normalization and a ReLU nonlinearity. After the last convolution, average pooling is applied. The output is followed by a fully-connected layer of 256 units and ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by ReLU nonlinearity, with an attention size of 16. Values are mapped from the output using a fully-connected layer of 64 units and ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is 0.002 and exponential decay with a rate of 0.9 is applied every 6k iterations; the model is trained for 84k iterations. For the prototypical learning model with softmax attention, the initial learning rate is 0.002 and exponential decay with a rate of 0.8 is applied every 8k iterations; the model is trained for 228k iterations. For the prototypical learning model with sparsemax attention, the initial learning rate is 0.001 and exponential decay with a rate of 0.93 is applied every 6k iterations; the model is trained for 228k iterations. All models use a batch size of 128 and gradient clipping above 20.
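The key/query/value attention step described above can be sketched as follows. This is a hedged numpy illustration, not the paper's code; the function names and the scaled dot-product scoring are assumptions (the paper uses softmax or sparsemax normalization over query-key scores to form a convex combination of database values):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def prototype_attention(queries, keys, values):
    """Relate input queries to database keys and mix database values.

    queries: (batch, d_att), keys: (db, d_att), values: (db, d_v).
    Returns attention weights (batch, db) and the convex combination
    of database values (batch, d_v). The weight given to each database
    sample is its prototype weight for that input.
    """
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    weights = softmax(scores, axis=-1)  # sparsemax is the sparse alternative
    return weights, weights @ values
```

Because the weights are nonnegative and sum to one per input, the mixed value is a convex combination of database representations, which is what makes the weights interpretable as prototype relevances.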
A.1.2 Fashion-MNIST dataset
We apply random cropping after padding each side by 2 pixels, random horizontal flipping, and per-image standardization. The base encoder uses a standard 32-layer ResNet architecture, similar to our MNIST experiments. For the baseline encoder, the initial learning rate is 0.0015 and exponential decay with a rate of 0.9 is applied every 10k iterations; the model is trained for 332k iterations. For prototypical learning with softmax attention, the initial learning rate is 0.0007 and exponential decay with a rate of 0.92 is applied every 8k iterations; the model is trained for 450k iterations. For prototypical learning with sparsemax attention, the initial learning rate is 0.001 and exponential decay with a rate of 0.9 is applied every 8k iterations; the model is trained for 392k iterations. For prototypical learning with sparsemax attention and sparsity regularization, the initial learning rate is 0.001 and exponential decay with a rate of 0.94 is applied every 8k iterations (the regularization coefficient is chosen differently when confidence regularization is applied); the model is trained for 440k iterations. All models use a batch size of 128 and gradient clipping above 20.
A.1.3 CIFAR-10 dataset
We apply random cropping after padding each side by 3 pixels, random horizontal flipping, random vertical flipping, and per-image standardization. The base encoder uses a standard 50-layer ResNet architecture. The number of filters is initially 16 and is doubled every 8 blocks. In each block, two convolutional layers transform the input, and the transformed output is added to the input after a convolution. Downsampling is applied by choosing a stride of 2 after certain blocks. Each convolution is followed by batch normalization and the ReLU nonlinearity. After the last convolution, average pooling is applied. The output is followed by a fully-connected layer of 256 units and the ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016), and then by a fully-connected layer of 512 units and the ReLU nonlinearity, followed by layer normalization. Keys and queries are mapped from the output using a fully-connected layer followed by the ReLU nonlinearity, with an attention size of 16. Values are mapped from the output using a fully-connected layer of 128 units and the ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is 0.002 and exponential decay with a rate of 0.95 is applied every 10k iterations; the model is trained for 940k iterations. For prototypical learning with softmax attention, the initial learning rate is 0.0035 and exponential decay with a rate of 0.95 is applied every 10k iterations; the model is trained for 625k iterations. For prototypical learning with sparsemax attention, the initial learning rate is 0.0015 and exponential decay with a rate of 0.95 is applied every 10k iterations; the model is trained for 905k iterations. For prototypical learning with sparsemax attention and sparsity regularization, the initial learning rate is 0.0015 and exponential decay with a rate of 0.95 is applied every 12k iterations (the regularization coefficient is chosen differently when confidence regularization is applied); the model is trained for 450k iterations. All models use a batch size of 128 and gradient clipping above 20.
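The exponential decay schedules used throughout these experiments can be expressed with a small helper. The staircase form (the decay factor applied once per fixed interval) is an assumption consistent with the "rate r applied every Nk iterations" phrasing above:

```python
def exponential_decay(initial_lr, decay_rate, decay_steps, step, staircase=True):
    """Learning rate after `step` iterations, with the rate multiplied
    by `decay_rate` every `decay_steps` iterations (staircase form)."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent
```

For instance, the CIFAR-10 baseline schedule (initial rate 0.002, decay 0.95 every 10k iterations) keeps the rate at 0.002 for the first 10k steps and multiplies it by 0.95 at each 10k boundary thereafter.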
CIFAR-10 experiments with noisy labels.
For the CIFAR-10 experiments with noisy labels for the base encoder, we only optimize the learning parameters. Noisy labels are sampled uniformly from the set of labels excluding the correct one. The baseline model with a noisy label ratio of 0.8 uses an initial learning rate of 0.001, decayed with a rate of 0.92 every 6k iterations, and is trained for 15k iterations. For the dropout approach, dropout with a rate of 0.1 is applied, and the model uses an initial learning rate of 0.002, decayed with a rate of 0.85 every 8k iterations, and is trained for 24k iterations. The baseline model with a noisy label ratio of 0.6 uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 6k iterations, and is trained for 12k iterations. For the dropout approach, dropout with a rate of 0.3 is applied, and the model uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 8k iterations, and is trained for 18k iterations. The baseline model with a noisy label ratio of 0.4 uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 6k iterations, and is trained for 15k iterations. For the dropout approach, dropout with a rate of 0.5 is applied, and the model uses an initial learning rate of 0.002, decayed with a rate of 0.92 every 6k iterations, and is trained for 18k iterations. For the experiments with the prototypical learning model with sparsemax attention, we optimize the learning parameters and the regularization coefficient. For the model with a noisy label ratio of 0.8, the initial learning rate is 0.0006 and exponential decay with a rate of 0.95 is applied every 8k iterations; the model is trained for 108k iterations. For the model with a noisy label ratio of 0.6, the initial learning rate is 0.001 and exponential decay with a rate of 0.9 is applied every 8k iterations; the model is trained for 92k iterations. For the model with a noisy label ratio of 0.4, the initial learning rate is 0.001 and exponential decay with a rate of 0.9 is applied every 6k iterations; the model is trained for 122k iterations.
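The noisy-label generation ("sampled uniformly from the set of labels excluding the correct one") can be sketched as follows; the function name `corrupt_labels` is hypothetical:

```python
import numpy as np

def corrupt_labels(labels, num_classes, noise_ratio, rng):
    """Flip a `noise_ratio` fraction of labels to a uniformly random
    *incorrect* class; the correct label is always excluded."""
    labels = labels.copy()
    n = len(labels)
    flip = rng.choice(n, size=int(noise_ratio * n), replace=False)
    for i in flip:
        # adding an offset in 1..num_classes-1 modulo num_classes is a
        # uniform draw over the wrong classes
        offset = rng.integers(1, num_classes)
        labels[i] = (labels[i] + offset) % num_classes
    return labels
```

The modular-offset trick guarantees every flipped label differs from the original while remaining uniform over the other classes.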
A.1.4 Fruits dataset
We apply random cropping after padding each side by 5 pixels, random horizontal flipping, random vertical flipping, and per-image standardization. In the encoder, downsampling is first applied with a convolutional layer with a stride of 2 and 16 filters, followed by downsampling with max-pooling with a stride of 2. After obtaining these inputs, a standard 32-layer ResNet architecture (similar to MNIST) is used, followed by a fully-connected layer of 128 units and the ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by the ReLU nonlinearity, with an attention size of 16. Values are mapped from the output using a fully-connected layer of 64 units and the ReLU nonlinearity, followed by layer normalization. Weight decay with a factor of 0.0001 is applied to the convolutional filters. The model uses a batch size of 128 and gradient clipping above 20.
A.1.5 ISIC Melanoma dataset
The ISIC Melanoma dataset is formed from the ISIC Archive (ISIC, 2016), which contains over 13k dermoscopic images collected from leading clinical centers internationally and acquired from a variety of devices within each center. The dataset consists of skin images with labels denoting whether they contain melanoma or are benign. We construct the training and validation dataset using 15122 images (13511 benign and 1611 melanoma cases), and the evaluation dataset using 3203 images (2867 benign and 336 melanoma). During training, benign cases are undersampled so that they constitute a 0.6 ratio of each batch, including the candidate database sets at training and inference. All images are resized to a fixed resolution. We apply random cropping after padding each side by 8 pixels, random horizontal flipping, random vertical flipping, and per-image standardization. In the encoder, downsampling is first applied with a convolutional layer with a stride of 2 and 16 filters, followed by downsampling with max-pooling with a stride of 2. After obtaining these inputs, the base encoder uses a standard 50-layer ResNet architecture (similar to CIFAR-10), followed by a fully-connected layer of 128 units and the ReLU nonlinearity, followed by layer normalization (Lei Ba et al., 2016). Keys and queries are mapped from the output using a fully-connected layer followed by the ReLU nonlinearity, with an attention size of 16. Values are mapped from the output using a fully-connected layer of 64 units and the ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is 0.002 and exponential decay with a rate of 0.9 is applied every 3k iterations; the model is trained for 220k iterations. For prototypical learning with softmax attention, the initial learning rate is 0.0006 and exponential decay with a rate of 0.9 is applied every 3k iterations; the model is trained for 147k iterations. For prototypical learning with sparsemax attention, the initial learning rate is 0.0006 and exponential decay with a rate of 0.9 is applied every 4k iterations; the model is trained for 166k iterations. All models use a batch size of 128 and gradient clipping above 20.
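The per-batch undersampling of benign cases to a 0.6 ratio can be sketched as follows; this is a hypothetical helper, not the paper's code:

```python
import numpy as np

def balanced_batch_indices(labels, batch_size, benign_label, benign_ratio, rng):
    """Draw a batch in which benign cases make up `benign_ratio` of the
    samples and melanoma cases fill the remainder."""
    benign_idx = np.flatnonzero(labels == benign_label)
    melanoma_idx = np.flatnonzero(labels != benign_label)
    n_benign = int(round(benign_ratio * batch_size))
    batch = np.concatenate([
        rng.choice(benign_idx, size=n_benign, replace=False),
        rng.choice(melanoma_idx, size=batch_size - n_benign, replace=False),
    ])
    rng.shuffle(batch)
    return batch
```

The same ratio would be enforced when sampling the candidate database sets, since the text states the 0.6 ratio applies to those as well.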
A.1.6 Animals with Attributes dataset
We train ProtoAttend with sparsemax attention using the features from a pretrained ResNet-50 as provided in (Yeh et al., 2018). To map the pretrained features, we simply insert a single fully-connected layer with 256 units with ReLU nonlinearity and layer normalization, followed by the individual fully-connected layers for keys, queries, and values (16, 16, and 64 units, respectively, with ReLU nonlinearity). Sparsity regularization is applied. We train the model for 70k iterations. The initial learning rate is 0.0006 and exponential decay with a rate of 0.8 is applied every 10k iterations. A classification accuracy above 91% is obtained on the test set.
A.2 Text data
A.2.1 DBPedia dataset
There are 14 output classes: Company, Educational Institution, Artist, Athlete, Office Holder, Mean Of Transportation, Building, Natural Place, Village, Animal, Plant, Album, Film, and Written Work. As the input, 16-dimensional trainable embeddings are mapped from the dictionary of 69 raw characters (Conneau et al., 2016). The maximum length is set to 448; longer inputs are truncated and shorter inputs are padded. The input embeddings are first transformed with a 1D convolutional block consisting of 64 filters with a kernel width of 3 and a stride of 2. Then, 8 convolution blocks as in (Conneau et al., 2016) are applied, with 64, 64, 128, 128, 256, 256, 512, and 512 filters, respectively. All use a kernel width of 3, and after every two layers, max pooling is applied with a kernel width of 3 and a stride of 2. All convolutions are followed by batch normalization and the ReLU nonlinearity. Convolutional filters use weight normalization with parameter 0.00001. The last convolution block is followed by k-max pooling with k=8 (Conneau et al., 2016). Finally, we apply two fully-connected layers with 1024 hidden units. In contrast to (Conneau et al., 2016), we also use layer normalization (Lei Ba et al., 2016) after the fully-connected layers, as we observe this leads to more stable training behavior. Keys and queries are mapped from the output using a fully-connected layer followed by the ReLU nonlinearity, with an attention size of 16. Values are mapped from the output using a fully-connected layer of 64 units and the ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is 0.0008 and exponential decay with a rate of 0.9 is applied every 8k iterations; the model is trained for 212k iterations. For the prototypical learning model with softmax attention, the initial learning rate is 0.0008 and exponential decay with a rate of 0.9 is applied every 8k iterations; the model is trained for 146k iterations. For the prototypical learning model with sparsemax attention, the initial learning rate is 0.0005 and exponential decay with a rate of 0.82 is applied every 8k iterations; the model is trained for 270k iterations. All models use a batch size of 128 and gradient clipping above 20. We do not apply any data augmentation.
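The k-max pooling step (k=8) keeps the k largest activations per channel along the temporal axis while preserving their original order; a minimal numpy sketch, with the function name assumed:

```python
import numpy as np

def k_max_pooling(x, k):
    """k-max pooling over the temporal axis (Conneau et al., 2016).

    x: (time, channels) -> (k, channels). The k largest values of each
    channel are kept in their original temporal order.
    """
    # indices of the top-k values per channel, restored to temporal order
    top = np.sort(np.argsort(x, axis=0)[-k:], axis=0)
    return np.take_along_axis(x, top, axis=0)
```

Keeping the temporal order (rather than sorting by magnitude) preserves coarse positional information for the fully-connected layers that follow.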
A.3 Tabular data
A.3.1 Adult Census Income
There are two output classes: whether or not the annual income is above $50k. Categorical variables such as ‘marital-status’ are mapped to multi-hot representations. Continuous variables are used after a fixed normalization transformation. For ‘age’, the transformation first subtracts 50 and then divides by 30. For ‘fnlwgt’, the transformation first takes the log, then subtracts 9, and then divides by 3. For ‘education-num’, the transformation first subtracts 6 and then divides by 6. For ‘hours-per-week’, the transformation first subtracts 50 and then divides by 50. For ‘capital-gain’ and ‘capital-loss’, the normalization takes the log, then subtracts 5, and then divides by 5. The concatenated features are then mapped to a 64-dimensional vector using a fully-connected layer, followed by the ReLU nonlinearity. The base encoder uses an LSTM architecture with 4 timesteps. At each timestep, the 64-dimensional inputs are applied after dropout with a rate of 0.5. The output of the last timestep is used after applying dropout with a rate of 0.5. Keys and queries are mapped from this output using a fully-connected layer followed by the ReLU nonlinearity, with an attention size of 16. Values are mapped from the output using a fully-connected layer of 16 units and the ReLU nonlinearity, followed by layer normalization. For the baseline encoder, the initial learning rate is 0.002 and exponential decay with a rate of 0.9 is applied every 2k iterations; the model is trained for 4.5k iterations. For the models with attention in the prototypical learning framework, the initial learning rate is 0.0005 and exponential decay with a rate of 0.92 is applied every 2k iterations; the softmax attention model is trained for 13.5k iterations and the sparsemax attention model for 11.5k iterations. For the model with sparsity regularization, the initial learning rate is 0.003 and exponential decay with a rate of 0.7 is applied every 2k iterations; the model is trained for 7k iterations. All models use a batch size of 128 and gradient clipping above 20. We do not apply any data augmentation.
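The fixed normalization transforms above can be collected in one hypothetical helper. Note that raw ‘capital-gain’/‘capital-loss’ values of zero would make the log diverge; the description above does not specify a guard (e.g. log1p), so none is assumed here:

```python
import numpy as np

def normalize_census_features(row):
    """Apply the fixed normalization transforms described above.
    `row` maps raw continuous feature names to values."""
    return {
        "age": (row["age"] - 50) / 30,
        "fnlwgt": (np.log(row["fnlwgt"]) - 9) / 3,
        "education-num": (row["education-num"] - 6) / 6,
        "hours-per-week": (row["hours-per-week"] - 50) / 50,
        # zero-valued gains/losses would need a guard in practice
        "capital-gain": (np.log(row["capital-gain"]) - 5) / 5,
        "capital-loss": (np.log(row["capital-loss"]) - 5) / 5,
    }
```

These shifts and scales roughly center each feature near zero at typical values (e.g. age 50 and 50 hours per week map to 0).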
Appendix B Additional prototype examples
Fig. 10 exemplifies prototypes for CIFAR-10. In most cases, we observe the similarity of discriminative features between inputs and prototypes. For example, the body figures of birds, the shape of tires, the face patterns of dogs, the body figures of frogs, and the appearance of the background sky for planes are among the features apparent in the examples.
Fig. 11 shows additional prototype examples for the DBPedia dataset. Prototypes have very similar sentence structures, words, and concepts, while categorizing the sentences into ontologies.
Fig. 12 shows example prototypes for ISIC Melanoma. In some cases, we observe commonalities between inputs and prototypes that distinguish melanoma cases, such as the non-circular geometry or irregularly-notched borders (Jerant et al., 2000). Compared to the other datasets, the ISIC Melanoma dataset yields lower interpretable prototype quality on average. We hypothesize this to be due to the perceptual difficulty of the problem as well as insufficient encoder performance, as shown by the lower classification accuracy (despite the acceptable AUC).
Fig. 13 shows more comparison examples for the prototypical learning framework with sparsemax attention vs. representer point selection (Yeh et al., 2018) on the Animals with Attributes dataset. For some cases, including chimpanzee, zebra, dalmatian, and tiger, ProtoAttend yields perceptually very similar samples. The similarity of the chimpanzee body form and background, the zebra patterns, the dalmatian pattern on the grass, and the tiger pattern and head pose are prominent. Representer point selection fails to capture such similarity features as effectively. On the other hand, for bat, otter, and wolf, the results are somewhat less satisfying. The wing part of the bat, the multiple otters with the background, and the color and furry head of the wolf seem to be captured, but with less apparent similarity than some other possible samples from the dataset. The representer point selection method also cannot be claimed to be successful in these cases. Lastly, for leopard, ProtoAttend yields only one nonzero prototype (which is indeed statistically rare given the model and sparsity choices). The pattern of the leopard image seems relevant, but it is also not fully satisfying to observe a single prototype that is not perceptually more similar. All of the test examples in Fig. 13 are classified correctly by our framework, and all of the shown prototypes are also from the correct classes.
Appendix C Comparison of confidence-controlled prediction for softmax vs. sparsemax
Figs. 14 and 15 show accuracy vs. the ratio of samples for softmax vs. sparsemax attention without confidence regularization. The baseline accuracy (at a 100% prediction ratio) is higher for softmax attention on some datasets and higher for sparsemax on others. On the other hand, the higher number of prototypes yielded by softmax attention results in a wider range for the confidence-controlled prediction tradeoff.
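For reference, sparsemax (Martins and Astudillo; cited in §3.1) projects a score vector onto the probability simplex and, unlike softmax, can assign exactly zero weight to low-scoring entries, which is why it yields fewer prototypes. A minimal numpy sketch of the closed-form projection:

```python
import numpy as np

def sparsemax(z):
    """Sparsemax: Euclidean projection of scores z onto the simplex.

    Returns a probability vector that may contain exact zeros, so only
    a few database samples receive nonzero prototype weight.
    """
    z_sorted = np.sort(z)[::-1]                 # descending scores
    cumsum = np.cumsum(z_sorted)
    ks = np.arange(1, len(z) + 1)
    support = 1 + ks * z_sorted > cumsum        # entries kept in the support
    k = ks[support][-1]                         # support size
    tau = (cumsum[k - 1] - 1) / k               # threshold
    return np.maximum(z - tau, 0.0)
```

For a strongly dominant score, sparsemax collapses the whole weight onto that entry, whereas softmax would still spread small weights over the rest.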
As an impactful case study, we consider the melanoma detection problem with the ISIC dataset (ISIC, 2016). In medical diagnosis, it is strongly desired to maintain a sufficiently high prediction performance, potentially by having medical experts verify the decisions of an AI system in cases where the model is not confident. By refraining from some predictions, as shown in Fig. 16, we demonstrate unprecedentedly high AUC values without using transfer learning or highly-customized models (Haenssle et al., 2018).
Appendix D Controlling sparsity via regularization
Fig. 17 shows the impact of the sparsity regularization coefficient on training. By varying its value, the number of prototypes can be efficiently controlled. For high values of the sparsity regularization coefficient, the model gets stuck at a point where it is forced to make decisions from a low number of prototypes before the encoder is properly learned, and hence typically yields considerably lower performance. We also observe that the sparsity mechanism via sparsemax attention yields better performance than softmax attention with high sparsity regularization.
Appendix E Prototype quality
In general, the following scenarios may yield low prototype quality:
1. Lack of related samples in the candidate database.
2. Perceptual difference between humans and encoders in determining discriminative features.
3. High intra-class variability that makes training difficult.
4. An imperfect encoder that cannot yield fully accurate representations of the input.
5. Insufficiency of relational attention in determining weights from queries and keys.
6. Inefficient decoupling between the encoder & attention blocks and the final decision block.
There can be problem-dependent fundamental limitations on (1)-(3), whereas (4)-(6) arise from the choice of models and losses and can be further improved. We leave the quantification of prototype quality using information-theoretic metrics or discriminative neural networks to future work.
Appendix F Understanding misclassification cases
One of the benefits of prototypical learning is the insight it provides into misclassification cases. Fig. 18 exemplifies prototypes with wrong labels, which give insight into why the model is confused by a particular input (e.g., due to the similarity of visual patterns). Such insights can be actionable for improving model performance, such as adding more training samples for the confusing classes or modifying the loss functions.