Stochastic Prototype Embeddings


Tyler R. Scott
University of Colorado, Boulder & Sensory Inc.
Karl Ridgeway
University of Colorado, Boulder
Michael C. Mozer
Google Research & University of Colorado, Boulder

Supervised deep-embedding methods project inputs of a domain to a representational space in which same-class instances lie near one another and different-class instances lie far apart. We propose a probabilistic method that treats embeddings as random variables. Extending a state-of-the-art deterministic method, Prototypical Networks (Snell et al., 2017), our approach supposes the existence of a class prototype around which class instances are Gaussian distributed. The prototype posterior is a product distribution over labeled instances, and query instances are classified by marginalizing relative prototype proximity over embedding uncertainty. We describe an efficient sampler for approximate inference that allows us to train the model at roughly the same space and time cost as its deterministic sibling. Incorporating uncertainty improves performance on few-shot learning and gracefully handles label noise and out-of-distribution inputs. Compared to the state-of-the-art stochastic method, Hedged Instance Embeddings (Oh et al., 2019), we achieve superior large- and open-set classification accuracy. Our method also aligns class-discriminating features with the axes of the embedding space, yielding an interpretable, disentangled representation.



1 Introduction

Supervised deep-embedding methods map instances from an input space to a latent embedding space in which same-label pairs are near and different-label pairs are far. The embedding thus captures semantic relationships without discarding inter-class structure. In contrast, consider a standard neural network classifier with a softmax output layer trained with a cross-entropy loss. Although its penultimate layer might be treated as an embedding, the classifier’s training objective attempts to orthogonalize all classes and thereby eliminate any information about inter-class structure.

Nearly all methods previously proposed for deep embeddings are deterministic: an instance projects to a single point in the embedding space. Deterministic embeddings fail to capture uncertainty due either to out-of-distribution inputs (e.g., data corruption) or label ambiguity (e.g., overlapping classes). Representing uncertainty is important for many reasons, including robust classification and decision making, informing downstream models, interpreting representations, and detecting out-of-distribution samples. In this article, we propose a method for discovering stochastic embeddings, where each embedded instance is a random variable whose distribution reflects the uncertainty in the embedding space.

Our proposed method, the Stochastic Prototype Embedding (SPE), is an extension of the Prototypical Network (PN) (Snell et al., 2017). As in the PN, our SPE assumes each class can be characterized by a prototype in the embedding space, and an instance is classified based on its proximity to a prototype. In the case of the SPE, the embeddings and prototypes are Gaussian random variables, each class instance is assumed to be a Gaussian perturbation of the prototype, and a query instance is classified by marginalizing over the embedding uncertainty. Using a synthetic data set, we demonstrate that the embedding uncertainty is related to both input and label noise. On a few-shot learning task, we show that the SPE significantly outperforms its state-of-the-art deterministic sibling, the PN. And on a challenging classification task, we find that the SPE outperforms Hedged Instance Embeddings (HIB) (Oh et al., 2019), the state-of-the-art stochastic embedding method.

2 Related Work

Supervised embedding methods are popular in the few-shot learning literature (Koch et al., 2015; Vinyals et al., 2016; Snell et al., 2017; Triantafillou et al., 2017; Finn et al., 2017; Edwards and Storkey, 2017; Scott et al., 2018; Ridgeway and Mozer, 2018; Mishra et al., 2018) where the goal is to classify query instances based on one or a small number of labeled exemplars of novel classes. These methods operate by embedding the queries and exemplars using a pre-trained network, and classifying each query according to its proximity to the exemplars. Embedding methods are also critical in open-set recognition domains such as face recognition and person re-identification (Chopra et al., 2005; Li et al., 2014; Yi et al., 2014; Zheng et al., 2015; Schroff et al., 2015; Liu et al., 2015; Ustinova and Lempitsky, 2016; Song et al., 2016; Wang et al., 2017).

Loss functions used to obtain embeddings can be characterized by the number of instances required to specify the loss. To describe these losses, we use the notation $z_i^k$ for an embedding of instance $i$ of class $k$. Pairwise losses attempt to minimize within-class distances, $\lVert z_i^k - z_j^k \rVert$, and maximize between-class distances, $\lVert z_i^k - z_j^{k'} \rVert$ for $k \neq k'$ (Chopra et al., 2005; Hadsell et al., 2006; Yi et al., 2014). Triplet losses attempt to ensure within-class instances are closer than between-class instances, $\lVert z_i^k - z_j^k \rVert < \lVert z_i^k - z_l^{k'} \rVert$ (Schroff et al., 2015; Song et al., 2016; Wang et al., 2017). Quadruplet losses attempt to ensure every within-class pair is closer than every between-class pair, $\lVert z_i^k - z_j^k \rVert < \lVert z_l^{k'} - z_m^{k''} \rVert$ for $k' \neq k''$ (Ustinova and Lempitsky, 2016). Finally, cluster-based losses use all instances of a class (Rippel et al., 2016; Fort, 2017; Song et al., 2017; Snell et al., 2017; Ridgeway and Mozer, 2018). In particular, the Prototypical Network (Snell et al., 2017) computes the mean of a set of instances of a class, $\rho_k = \frac{1}{|S_k|}\sum_{i \in S_k} z_i^k$, and ensures that additional instances of that class, $z^k$, satisfy a proximity constraint such as $\lVert z^k - \rho_k \rVert < \lVert z^k - \rho_{k'} \rVert$ for $k' \neq k$. Cluster-based methods achieve state-of-the-art performance relative to pairwise and triplet losses, as one might expect given the chronology of publication.
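To make the taxonomy concrete, here is a minimal numpy sketch (our own toy construction, not code from any of the cited papers) of the distance constraints each loss family encourages:

```python
import numpy as np

# Toy 2-D embeddings for two well-separated classes.
rng = np.random.default_rng(0)
z_a = rng.normal(loc=[0.0, 0.0], scale=0.1, size=(5, 2))  # class-a instances
z_b = rng.normal(loc=[3.0, 0.0], scale=0.1, size=(5, 2))  # class-b instances

def dist(u, v):
    return np.linalg.norm(u - v)

# Pairwise: within-class distances should be small, between-class large.
within = dist(z_a[0], z_a[1])
between = dist(z_a[0], z_b[0])

# Triplet: an anchor should be closer to a positive than to a negative.
triplet_ok = dist(z_a[0], z_a[1]) < dist(z_a[0], z_b[0])

# Cluster-based (PN-style): a held-out instance should be closer to its own
# class prototype (mean of that class's instances) than to any other prototype.
proto_a, proto_b = z_a[1:].mean(axis=0), z_b.mean(axis=0)
cluster_ok = dist(z_a[0], proto_a) < dist(z_a[0], proto_b)
print(within < between, triplet_ok, cluster_ok)
```

All three constraints are satisfied here by construction; the losses differ in how many instances each term of the objective touches.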

Recently, probabilistic embedding methods have begun to appear. Allen et al. (2019) extend PNs via Bayesian nonparametric methods that treat each prototype as a mixture distribution, though they do not explore uncertainty in the embedding space nor leverage the embedding to handle noisy inputs and noisy labels, which is a significant aspect of our work. Vilnis and McCallum (2018) propose an unsupervised method for learning density-based word embeddings, where each embedding is represented by a Gaussian distribution; however, this work is not directly comparable to our supervised method. Deep Variational Transfer (Belhaj et al., 2018) is a generative form of the discriminative model we propose; it has the drawback that it must model the input distribution. Its authors applied the approach to covariate shift, a somewhat different problem from the one we tackle.

Two prior methods have been proposed for discovering stochastic embeddings in a supervised setting, i.e., for few-shot and open-set recognition. The Hedged Instance Embedding (HIB) (Oh et al., 2019) utilizes a probabilistic alternative to the contrastive loss and is trained using a variational approximation to the information bottleneck principle. HIB is critically dependent on a constant, $\beta$, that determines characteristics of the information bottleneck (i.e., how much of the input entropy is retained in the embedding). Choosing this constant is a matter of art. The Oracle-Prioritized Belief Network (OPBN) (Karaletsos et al., 2016) is a generative model that learns a joint distribution over inputs and oracle-provided triplet constraints. The OPBN was not tested on few-shot and open-set recognition because it requires extensions to be applied to classification tasks. In the deterministic setting, Scott et al. (2018) argue that cluster-based methods outperform pairwise and triplet methods; thus, we have reason to expect that, in a stochastic setting, a cluster-based method like the one we propose in this article, SPE, will outperform pairwise (HIB) and triplet (OPBN) methods.

3 The Model

The SPE assumes that the latent representation, $z$, is a Gaussian RV conditioned on the input, $x$:

$$p(z \mid x) = \mathcal{N}\big(z;\, \mu_x,\, \mathrm{diag}(\sigma_x^2)\big) \qquad (1)$$

with mean, $\mu_x$, and variance, $\sigma_x^2$, computed by a deep neural network, similar to a Variational Autoencoder (Kingma and Welling, 2014). The classification, $y$, in turn is conditioned on $z$, with $\Pr(y \mid z)$ taking the same form as in the original PN (Snell et al., 2017), to be described shortly. Given an input, a class prediction is made by marginalizing over the embedding uncertainty:

$$\Pr(y \mid x) = \int_z \Pr(y \mid z)\, p(z \mid x)\, dz. \qquad (2)$$

Figure 1: (a) Illustration of the stochastic prototype embedding. The model learns a mapping from the input space to the embedding space, in which same-class instances are near and different-class instances are far. Embeddings are represented as Gaussian random variables. Prototypes, denoted by symbols in the embedding, are formed via a confidence-weighted average of the embeddings of instances known to belong to a class (support instances). Prototype uncertainty is depicted by the dotted ovals. Given the prototypes, a class prediction is made for a query instance by marginalizing a softmax prediction over the embedding space. (b) Depiction of the intersection sampler.

Figure 1a depicts the relationship between the input, latent, and class representations. We train the SPE using the standard few-shot learning paradigm, consisting of a sequence of episodes, each containing instances of a randomly drawn subset of classes. We split the instances into support examples, defining a set $S$, and query examples. The support instances for each class $k$, denoted $S_k$, are used to determine the class prototype, $\rho_k$, and the query instances are evaluated to predict class labels (Equation 2).

3.1 Forming class prototypes

In the SPE, each class $k$ has an associated prototype, $\rho_k$, in the embedding space, and each instance of class $k$ projects to an embedding in the neighborhood of $\rho_k$. Specifically, we assume the embedding mean of support instance $i$, $\mu_i$, is a Gaussian perturbation of the prototype:

$$\mu_i \mid \rho_k \sim \mathcal{N}\big(\rho_k,\, \mathrm{diag}(\sigma_i^2)\big). \qquad (3)$$

We assume that the prototype is consistent with all support instances, allowing us to express the likelihood of $\rho_k$ as a product distribution:

$$p(\rho_k \mid S_k) \propto \prod_{i \in S_k} \mathcal{N}\big(\mu_i;\, \rho_k,\, \mathrm{diag}(\sigma_i^2)\big). \qquad (4)$$

Because each factor is Gaussian, the resulting product is too:

$$\rho_k \mid S_k \sim \mathcal{N}\big(\bar\mu_k,\, \mathrm{diag}(\bar\sigma_k^2)\big), \qquad (5)$$

where $\bar\sigma_k^2 = \big(\sum_{i \in S_k} \sigma_i^{-2}\big)^{-1}$, $\bar\mu_k = \bar\sigma_k^2 \odot \sum_{i \in S_k} \sigma_i^{-2} \odot \mu_i$, and $\odot$ denotes the Hadamard (elementwise) product. Essentially, the prototype is a confidence-weighted average of the support instances. This formulation has a clear advantage over the deterministic PN, which is premised on an unweighted average, because it de-emphasizes noisy support instances.
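As a concrete illustration, the confidence-weighted prototype of Equation 5 can be computed in a few lines of numpy; the variable names and toy values below are our own:

```python
import numpy as np

# Sketch of Equation 5: the prototype posterior formed from diagonal-Gaussian
# support embeddings (mu_i, var_i), one row per support instance.
def prototype(mu, var):
    """mu, var: arrays of shape (num_support, dim). Returns (proto_mu, proto_var)."""
    precision = 1.0 / var                                # elementwise precisions
    proto_var = 1.0 / precision.sum(axis=0)              # (sum_i 1/var_i)^-1
    proto_mu = proto_var * (precision * mu).sum(axis=0)  # confidence-weighted mean
    return proto_mu, proto_var

mu = np.array([[0.0, 0.0], [2.0, 2.0]])
var = np.array([[1.0, 1.0], [9.0, 9.0]])  # second support instance is much noisier
p_mu, p_var = prototype(mu, var)
print(p_mu)  # pulled toward the low-variance instance: [0.2 0.2]
```

An unweighted PN prototype would sit at [1.0, 1.0]; the precision weighting discounts the noisy instance, which is exactly the behavior exploited in the corrupted-support experiments.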

3.2 Prediction and approximate inference

We assume a softmax prediction for a query embedding, $z$:

$$\Pr(y = k \mid z) = \frac{\exp\big(-\lVert z - \bar\mu_k \rVert^2\big)}{\sum_{k'} \exp\big(-\lVert z - \bar\mu_{k'} \rVert^2\big)} \qquad (6)$$

with $\bar\mu_k$ as before, yielding the class posterior for query $x$:

$$\Pr(y = k \mid x) = \int_z \frac{\exp\big(-\lVert z - \bar\mu_k \rVert^2\big)}{\sum_{k'} \exp\big(-\lVert z - \bar\mu_{k'} \rVert^2\big)}\, \mathcal{N}\big(z;\, \mu_x,\, \mathrm{diag}(\sigma_x^2)\big)\, dz. \qquad (7)$$

The class distribution is equivalent to that produced by the deterministic PN as $\sigma_x^2 \to 0$, provided $\bar\sigma_k^2 = \bar\sigma_{k'}^2$ for all class pairs $(k, k')$. In the general case, however, the integral has no closed-form solution; thus, we must sample to approximate $\Pr(y = k \mid x)$, both for training and evaluation. We employ two samplers, which we refer to as naïve and intersection.

3.2.1 Naïve sampling

A direct approach to approximating the class posterior is to express Equation 2 as an expectation, $\Pr(y \mid x) = \mathbb{E}_{z \sim p(z \mid x)}\big[\Pr(y \mid z)\big]$, and to replace the expectation with an average over a set of samples. We utilize the reparameterization trick of Kingma and Welling (2014) to train the model. Although this is the simplest approach, it is sample-inefficient during training, and performance degrades when the number of samples is reduced.
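A minimal numpy sketch of the naïve sampler, with our own variable names and toy values (in the actual model, mu_x and var_x are produced by the deep network):

```python
import numpy as np

# Naive sampler: Monte Carlo estimate of Pr(y=k | x) by drawing embeddings
# z ~ N(mu_x, diag(var_x)) and averaging PN-style softmax predictions.
def softmax_over_prototypes(z, protos):
    logits = -((z[:, None, :] - protos[None, :, :]) ** 2).sum(-1)  # -||z - mu_k||^2
    logits -= logits.max(axis=1, keepdims=True)                    # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def naive_posterior(mu_x, var_x, protos, n_samples=1000, seed=0):
    rng = np.random.default_rng(seed)
    # Reparameterization: z = mu + sigma * eps, with eps ~ N(0, I).
    eps = rng.standard_normal((n_samples, mu_x.shape[0]))
    z = mu_x + np.sqrt(var_x) * eps
    return softmax_over_prototypes(z, protos).mean(axis=0)

protos = np.array([[0.0, 0.0], [4.0, 0.0]])  # two class prototype means
post = naive_posterior(np.array([1.0, 0.0]), np.array([0.25, 0.25]), protos)
print(post)  # mass concentrated on the nearer prototype
```

The estimate only stabilizes with many samples per query, which is the sample inefficiency the intersection sampler below is designed to avoid.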

3.2.2 Intersection sampling

In Equation 7, the product of Gaussian densities in the numerator can be rewritten:

$$\mathcal{N}\big(z;\, \mu_x,\, \mathrm{diag}(\sigma_x^2)\big)\, \exp\big(-\lVert z - \bar\mu_k \rVert^2\big) = s_k\, \mathcal{N}\big(z;\, \mu_{xk},\, \Sigma_{xk}\big), \qquad (8)$$

where $\Sigma_{xk} = \big(\mathrm{diag}(\sigma_x^{-2}) + 2I\big)^{-1}$, $\mu_{xk} = \Sigma_{xk}\big(\mathrm{diag}(\sigma_x^{-2})\,\mu_x + 2\bar\mu_k\big)$, and $s_k$ is a scale factor independent of $z$. Substituting Equation 8 into Equation 7,

$$\Pr(y = k \mid x) = s_k\, \mathbb{E}_{z \sim \mathcal{N}(\mu_{xk},\, \Sigma_{xk})}\left[\frac{1}{\sum_{k'} \exp\big(-\lVert z - \bar\mu_{k'} \rVert^2\big)}\right]. \qquad (9)$$

By approximating the expectation with samples from $\mathcal{N}(\mu_{xk}, \Sigma_{xk})$, we obtain a sampler that focuses on the intersection of the input distribution and a given class distribution, as illustrated in Figure 1b. During training with a cross-entropy loss, we need only sample for the known (target) class $k$. As we will demonstrate, this method is more robust and significantly more sample efficient than the naïve sampler, requiring only a single sample to train effectively.
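The intersection distribution has a simple closed form for diagonal covariances, because $\exp(-\lVert z - m \rVert^2)$ is proportional to a Gaussian with precision $2I$. A numpy sketch, with our own variable names:

```python
import numpy as np

# Sketch of Equation 8: the product N(z; mu_x, diag(var_x)) * exp(-||z - proto||^2)
# is itself an (unnormalized) diagonal Gaussian whose precisions add.
def intersection_params(mu_x, var_x, proto_k):
    prec = 1.0 / var_x + 2.0                          # combined diagonal precision
    var_xk = 1.0 / prec
    mu_xk = var_xk * (mu_x / var_x + 2.0 * proto_k)   # precision-weighted mean
    return mu_xk, var_xk

mu_x, var_x = np.array([1.0, 0.0]), np.array([0.25, 0.25])
proto = np.array([0.0, 0.0])
mu_xk, var_xk = intersection_params(mu_x, var_x, proto)
print(mu_xk, var_xk)  # mean lies between mu_x and the prototype
```

A single draw from N(mu_xk, diag(var_xk)) already lands where both the query density and the target-class density are large, which is why one sample per query suffices for training.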

4 Experimental Results

We report on three sets of experiments. In Section 4.1, we demonstrate, using a synthetic data set, that SPE infers the generative structure of a domain, disentangles class-discriminating features, and provides meaningful estimates of label uncertainty and input noise. In Section 4.2, we show that SPE obtains state-of-the-art results on few-shot learning via a comparison to its deterministic sibling, PN, the previous state-of-the-art method. We evaluate on a standard data set used to compare methods in the few-shot learning literature, Omniglot (Lake et al., 2015). In Section 4.3, we show that SPE obtains state-of-the-art results on large-set classification via a comparison to the only other fully developed stochastic method for supervised embeddings, HIB (Oh et al., 2019). We evaluate on the only data set that Oh et al. (2019) used to explore HIB, a multi-digit variant of MNIST. For details regarding network architectures and hyperparameters, see Appendix A, and for simulation details, including initialization choices, see Appendix B.

4.1 Synthetic color-orientation data set

The data set consists of small color images of ‘L’ shapes, with four classes that are distinguished by orientation, color, or both (Figure 2a). Instances are sampled from a class-conditional isotropic Gaussian distribution in the generative space. (The isotropy of these qualitatively different dimensions comes from the fact that both can be coded as directional quantities.) Because classes overlap on both the color and orientation dimensions, elicited embeddings should indicate increased uncertainty near class boundaries. Full details of the synthetic data set can be found in Appendix A.2.

We trained a two-dimensional, intersection-sampling SPE on samples from this domain, using two instances per class to form prototypes. Classification accuracy on held-out samples approaches that of a Bayes-optimal classifier, whose accuracy is limited by the class overlap. For visualization, Figure 2b presents an array of examples with the class centroids in the corners and the other examples obtained by linear interpolation in the generative space. The resulting embeddings are presented in Figure 2c. Although the correspondence between Figures 2b and 2c seems trivial (mirror one set along the horizontal axis to obtain the other), remember that the input space is high-dimensional and the latent space is only two-dimensional. The network has captured the structure of the domain by disentangling the two factors of variation. Further, the embedding variance encodes label ambiguity; instances halfway between two classes on one dimension have maximal variance along that dimension.

Label ambiguity is one type of uncertainty. An equally important source of uncertainty comes from noisy or out-of-distribution (OOD) inputs. We examined OOD inputs generated in two different ways. In the left panel of Figure 3, we show the consequence of adding pixel hue noise to the four class centroids. Only one of these centroids is shown along the abscissa, but all four are used to make the graph, with many samples per noise level. The grey and black bars in the graph indicate variance on the horizontal and vertical dimensions of the embedding space, respectively. As pixel hue noise increases, uncertainty in color grows but uncertainty in orientation does not. In the right panel of Figure 3, we show the consequence of shortening the leg length of the shape. Shortening the legs removes cues used to determine both color and orientation; as a result, uncertainty grows on both dimensions.

Figure 2: (a) Samples from the four classes in our synthetic data set. In each plot, class centroids are circled, along with samples spanning both orientation and color; a sample’s transparency is set according to its class-conditional likelihood. Both dimensions can be coded as directional variables. (b) An array of examples, with the four class centroids located in the corners and the other examples obtained by linear interpolation in the generative space. (c) The two-dimensional stochastic prototype embedding for the examples in (b). Each shape is plotted at the mean of its embedding, and the oval outlines represent equiprobability contours.
Figure 3: Synthetic data set: uncertainty on the two embedding dimensions as it becomes more difficult to discern the hue (left) and orientation (right).

4.2 Omniglot

The Omniglot data set contains images of labeled, handwritten characters from diverse alphabets and is one of the standard data sets for comparing methods in the few-shot learning literature. It contains 1,623 unique characters, each with 20 instances. Following Snell et al. (2017), each grayscale image is resized from 105×105 to 28×28, and we augment the original classes with rotations in multiples of 90°, resulting in 6,492 total classes. We train PNs and SPEs episodically, where a training episode contains a set of randomly sampled classes with a fixed number of query instances per class.

Figure 4: Test classification accuracy as a function of the number of training samples per query instance for naïve-sampling and intersection-sampling SPEs on a few-shot, multi-class Omniglot task. Performance is a mean over replications of running the model; error bars show standard error of the mean.
Figure 5: Two-dimensional embedding learned by the SPE on the Omniglot test set. Each square thumbnail image in the figure is a randomly sampled instance from one of the randomly sampled test classes, and the location of the image represents the location of the class prototype. The images have a gray bounding box for visualization purposes only.

To compare the relative effectiveness of the naïve and intersection samplers, we train the SPE on Omniglot, varying both the sampler and the number of samples drawn per training query. We evaluate in a fixed-shot, multi-class setting, where shot refers to the number of support examples used to compute each prototype. Figure 4 shows test classification accuracy as the number of samples drawn per training trial increases. As we previously stated, the intersection-sampling SPE is far more sample efficient, to the point that the intersection sampler with a single sample outperforms the naïve sampler with many samples. We have verified that the pattern in Figure 4 is consistent across simulations; consequently, we present only intersection-sampling SPE results in the remainder of the article, and all SPEs are trained with a single sample per query. This choice puts the SPE on par with the PN in time and space requirements, even though using more samples may boost classification accuracy, as suggested by the trend in Figure 4.

Figure 5 is a visualization of a two-dimensional embedding learned by the intersection-sampling SPE on Omniglot. All classes shown in the figure were held out during training. Omniglot characters clearly vary along more than two dimensions, so a two-dimensional SPE cannot learn a fully disentangled representation as it did with the synthetic data set. However, we can still interpret the axes of the embedding. The horizontal axis appears to represent character complexity, with single-stroke characters on the left and many-stroke characters on the right. The vertical axis appears to encode the aspect ratio of the characters, with horizontally extended characters at the bottom and vertically extended characters at the top.

Figure 6a compares the PN and SPE with low-dimensional embeddings on Omniglot test classes. Each bar is the mean accuracy across four conditions, crossing 1- and 5-shot with two class-set sizes. The first pair of bars performs the standard comparison in which the (1- or 5-instance) support set is used to obtain an embedding for each class, prototypes are formed, and query instances are classified. SPE is reliably better than the PN. Because the Omniglot data are carefully curated, the instances have little noise and therefore offer little opportunity to leverage SPE’s assessment of uncertainty. Consequently, we corrupted instances by masking out rectangular regions of the input, as proposed by Oh et al. (2019). (See Appendix E for details.) The second and third sets of bars in Figure 6a correspond to the situations where the support and query instances are corrupted, respectively. SPE’s advantage over PN increases significantly when the support instances are corrupted, because SPE’s confidence-weighted prototypes (Equation 5) discount noisier support examples. Although the SPE is still superior when only the query is corrupted, the benefit is small. We also compared PN and SPE using higher-dimensional embeddings, but in that regime both methods are near ceiling on this data set, resulting in comparable performance. (See Appendix D for additional results, broken down by condition.)

To emphasize, SPE outperforms the PN, arguably the leading few-shot learning method, especially when inputs are corrupted, at essentially the same computational cost for training. And by providing an estimate of uncertainty associated with embedded instances, the SPE offers the possibility of detecting OOD samples and informing downstream systems that operate on the embedding.

4.3 N-digit MNIST

The N-digit MNIST data set was proposed to evaluate HIB (Oh et al., 2019); it is formed by horizontally concatenating N MNIST digit images, yielding images of 28 × 28N pixels. To compare with HIB, we study 2- and 3-digit MNIST and use a network architecture identical to that of Oh et al. (2019). Oh et al. (2019) split the data into a training set covering a subset of the total classes, a seen test set, and an unseen test set. For 2-digit MNIST, the seen test set has the same classes as the training set and the unseen test set has the remaining classes. For 3-digit MNIST, the seen and unseen test sets each contain a sample of the seen or unseen classes, respectively. We use the same train and test splits as Oh et al. (2019), but we further divide the training split to include a validation set for early stopping.

Figure 6: (a) Comparison of few-shot accuracy on Omniglot test classes for the PN (Snell et al., 2017) and our SPE. (b) Comparison of test accuracy on seen classes of 2- and 3-digit MNIST for HIB (Oh et al., 2019) and our SPE. (c) Same as (b), except for unseen classes. In (a)-(c), error bars reflect standard error of the mean, corrected to remove cross-condition variance (Masson and Loftus, 2003).
Figure 7: Two-dimensional embedding learned by the SPE on the 2-digit MNIST test set. A class is specified by a two-digit number. In both figures, the location of a class corresponds to the mean of its prototype, computed from support instances in the test set. Digits surrounded by a black border are classes that were not seen during training. In the left and right figures, the prototypes are colored according to the first and second digit of the class, respectively.

Figure 7 shows two views of the two-dimensional embedding learned by the SPE on the 2-digit MNIST test set. Each number is a class label, and its location in the space corresponds to the mean of its prototype. In the left plot, each class is colored according to the first digit; the right plot shows the same embedding, but with each prototype colored according to the second digit. The SPE learns a strikingly robust factorial representation in which the horizontal dimension represents the first digit of a class and the vertical dimension represents the second digit. A black bounding box indicates the unseen test classes, i.e., classes not presented during training. Impressively, the unseen test classes are embedded in exactly the positions where they belong, indicating that the SPE can discover relationships among classes that allow it to generalize to classes it has never seen during training. Furthermore, the embedding has captured inter-class similarity structure by placing visually similar digits close to one another. For example, on both the vertical and horizontal bands, nines (teal) and fours (purple) are adjacent, and fives (brown) and threes (red) are adjacent. The adjacency relationships vary a bit from one dimension of the mapping to the other; for example, sixes (pink) are adjacent to eights (yellow) and zeros (blue) in the vertical bands, but adjacent to fives (brown) and zeros in the horizontal bands. HIB is able to discover a similar structure along one dimension (Oh et al., 2019), but its second dimension is somewhat more entangled, suggesting that the SPE learns a more robust representation. Additionally, embeddings for the unseen classes are not presented for HIB. The ability to sensibly embed novel classes is essential for any model that will be used for open-set recognition or few-shot learning.

Figure 6b,c compare N-digit MNIST test accuracy on seen and unseen classes, respectively. (HIB results are from Oh et al. (2019); we thank the authors for providing us results on unseen classes, which were not included in their publication.) Each bar is the mean test accuracy across the Cartesian product of conditions specified by the number of MNIST digits per image and the dimensionality of the embedding. As in the Omniglot simulation, we varied whether support and query instances were clean or corrupted. The SPE outperforms HIB in all six comparisons and is worse in only a small minority of the individual conditions. As in the Omniglot simulation, SPE shines brightest when support instances may be corrupted. (Appendix A.3 provides tabular results by condition, not only for HIB and SPE but also for their deterministic counterparts, the contrastive loss and PN. Because the deterministic methods perform consistently worse than the stochastic methods, we omit them from the figure.)

Whereas SPE is a discriminative model with a specified classification procedure, Oh et al. (2019) had the freedom to design one. They use all available data, many examples per class, and perform leave-one-out k-nearest-neighbor classification. To be consistent with our episodic test procedure, the SPE uses only a few support instances per class to form prototypes. It is particularly impressive that the SPE, based on a single stored prototype and a small fraction of the labeled data, can outperform a memory-based nonparametric method that is able to model arbitrary distributions in latent space.

5 Discussion and Conclusions

Our Stochastic Prototype Embedding (SPE) method outperforms a state-of-the-art deterministic method, the Prototypical Net (PN), on few-shot learning, particularly when support instances may be corrupted. Because the SPE reduces to the PN under certain restrictions, it seems unlikely to fare worse; and because it can handle uncertainty in both the query and support set, it has great opportunity to improve on the PN. Many extensions have been proposed to the PN (e.g., Fort, 2017; Allen et al., 2019). These extensions are largely compatible with ours, and the methods can potentially be combined to attain even stronger few-shot learning performance under uncertainty.

SPE also significantly outperforms the only existing alternative stochastic method, the Hedged Instance Embedding (HIB), on the complete battery of large-set classification tasks used to evaluate HIB. Beyond its performance gains, SPE has no hand-tuned parameters, whereas HIB has a constant, $\beta$, that determines characteristics of the information bottleneck (i.e., how much of the input entropy is retained in the embedding). Although one could simply set $\beta = 0$, doing so would encourage the net to perform like a softmax classifier and discard all information about inter-class similarity. Such similarities are essential for generalizing to unseen classes (e.g., Figure 7).

We proposed and evaluated an intersection sampler to train the SPE, which makes the SPE as time and space efficient for training as the deterministic PN, and more efficient for training than HIB, which relies on multiple samples per item. (Our evaluation method for SPE presently involves drawing samples from the naïve sampler, though this conservative decision was arbitrary and not tuned.)

An unanticipated virtue of SPE is its ability to obtain interpretable, disentangled representations (Figures 2, 5, 7). Because uncertainty is encoded in a diagonal covariance matrix, any classification ambiguity maps to uncertainty in the value of individual features of the embedding. Thus, class-discriminating feature dimensions must align with the principal axes of the embedding space. In contrast to traditional unsupervised disentangling methods, which aim to discover the underlying generative factors of a domain, the SPE obtains a supervised analog in which the underlying class-discriminative factors are represented explicitly. This representation facilitates generalization to novel, unseen classes and is therefore valuable for few-shot and lifelong learning paradigms.


  • Allen et al. (2019) Allen, K. R., Shelhamer, E., Shin, H., and Tenenbaum, J. B. (2019). Infinite Mixture Prototypes for Few-Shot Learning. arXiv e-prints 1902.04552 cs.LG.
  • Belhaj et al. (2018) Belhaj, M., Protopapas, P., and Pan, W. (2018). Deep Variational Transfer: Transfer Learning through Semi-supervised Deep Generative Models. arXiv e-prints 1812.03123 cs.LG.
  • Chopra et al. (2005) Chopra, S., Hadsell, R., and LeCun, Y. (2005). Learning a Similarity Metric Discriminatively, with Application to Face Verification. In IEEE Conference on Computer Vision and Pattern Recognition.
  • Edwards and Storkey (2017) Edwards, H. and Storkey, A. (2017). Towards a Neural Statistician. In International Conference on Learning Representations.
  • Finn et al. (2017) Finn, C., Abbeel, P., and Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. In International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 1126–1135.
  • Fort (2017) Fort, S. (2017). Gaussian Prototypical Networks for Few-Shot Learning on Omniglot. In Second Workshop on Bayesian Deep Learning (NIPS 2017).
  • Hadsell et al. (2006) Hadsell, R., Chopra, S., and LeCun, Y. (2006). Dimensionality Reduction by Learning an Invariant Mapping. In IEEE Computer Society Conference on Computer Vision and Pattern Recognition, volume 2, pages 1735–1742.
  • Karaletsos et al. (2016) Karaletsos, T., Belongie, S., and Rätsch, G. (2016). Bayesian Representation Learning with Oracle Constraints. In International Conference on Learning Representations.
  • Kingma and Welling (2014) Kingma, D. P. and Welling, M. (2014). Auto-Encoding Variational Bayes. In International Conference on Learning Representations.
  • Koch et al. (2015) Koch, G., Zemel, R., and Salakhutdinov, R. (2015). Siamese Neural Networks for One-Shot Image Recognition. In ICML Deep Learning Workshop, volume 2.
  • Lake et al. (2015) Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-Level Concept Learning through Probabilistic Program Induction. Science, 350(6266):1332–1338.
  • Li et al. (2014) Li, W., Zhao, R., Xiao, T., and Wang, X. (2014). DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. In IEEE Conference on Computer Vision and Pattern Recognition.
  • Liu et al. (2015) Liu, J., Deng, Y., Bai, T., Wei, Z., and Huang, C. (2015). Targeting Ultimate Accuracy: Face Recognition via Deep Embedding. arXiv e-prints 1506.07310 cs.CV.
  • Masson and Loftus (2003) Masson, M. E. J. and Loftus, G. R. (2003). Using Confidence Intervals for Graphically Based Data Interpretation. Canadian Journal of Experimental Psychology/Revue canadienne de psychologie expérimentale, 57:203–220.
  • Mishra et al. (2018) Mishra, N., Rohaninejad, M., Chen, X., and Abbeel, P. (2018). A Simple Neural Attentive Meta-Learner. In International Conference on Learning Representations.
  • Oh et al. (2019) Oh, S. J., Gallagher, A. C., Murphy, K. P., Schroff, F., Pan, J., and Roth, J. (2019). Modeling Uncertainty with Hedged Instance Embeddings. In International Conference on Learning Representations.
  • Ridgeway and Mozer (2018) Ridgeway, K. and Mozer, M. C. (2018). Learning Deep Disentangled Embeddings With the F-Statistic Loss. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., and Garnett, R., editors, Advances in Neural Information Processing Systems 31, pages 185–194. Curran Associates, Inc.
  • Rippel et al. (2016) Rippel, O., Paluri, M., Dollar, P., and Bourdev, L. (2016). Metric Learning with Adaptive Density Discrimination. In International Conference on Learning Representations.
  • Schroff et al. (2015) Schroff, F., Kalenichenko, D., and Philbin, J. (2015). FaceNet: A Unified Embedding for Face Recognition and Clustering. In IEEE Conference on Computer Vision and Pattern Recognition.
  • Scott et al. (2018) Scott, T., Ridgeway, K., and Mozer, M. C. (2018). Adapted Deep Embeddings: A Synthesis of Methods for k-Shot Inductive Transfer Learning. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., and Garnett, R., editors, Advances in Neural Information Processing Systems 31, pages 76–85. Curran Associates, Inc.
  • Snell et al. (2017) Snell, J., Swersky, K., and Zemel, R. (2017). Prototypical Networks for Few-shot Learning. In Advances in Neural Information Processing Systems 30, pages 4077–4087.
  • Song et al. (2017) Song, H. O., Jegelka, S., Rathod, V., and Murphy, K. (2017). Deep Metric Learning via Facility Location. In IEEE Conference on Computer Vision and Pattern Recognition, pages 2206–2214.
  • Song et al. (2016) Song, H. O., Xiang, Y., Jegelka, S., and Savarese, S. (2016). Deep Metric Learning via Lifted Structured Feature Embedding. In IEEE Conference on Computer Vision and Pattern Recognition.
  • Triantafillou et al. (2017) Triantafillou, E., Zemel, R., and Urtasun, R. (2017). Few-Shot Learning Through an Information Retrieval Lens. In Advances in Neural Information Processing Systems 30, pages 2255–2265.
  • Ustinova and Lempitsky (2016) Ustinova, E. and Lempitsky, V. (2016). Learning Deep Embeddings with Histogram Loss. In Advances in Neural Information Processing Systems 29, pages 4170–4178.
  • Vilnis and McCallum (2018) Vilnis, L. and McCallum, A. (2018). Word Representations via Gaussian Embedding. In International Conference on Learning Representations.
  • Vinyals et al. (2016) Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., and Wierstra, D. (2016). Matching Networks for One Shot Learning. In Advances in Neural Information Processing Systems 29, pages 3630–3638.
  • Wang et al. (2017) Wang, J., Zhou, F., Wen, S., Liu, X., and Lin, Y. (2017). Deep Metric Learning with Angular Loss. In IEEE International Conference on Computer Vision, pages 2612–2620.
  • Yi et al. (2014) Yi, D., Lei, Z., Liao, S., and Li, S. Z. (2014). Deep Metric Learning for Person Re-identification. In International Conference on Pattern Recognition, pages 34–39.
  • Zheng et al. (2015) Zheng, L., Shen, L., Tian, L., Wang, S., Wang, J., and Tian, Q. (2015). Scalable Person Re-identification: A Benchmark. In IEEE International Conference on Computer Vision.

Appendix A Network Architectures and Hyperparameters

A.1 Omniglot

For all Omniglot experiments, the network consisted of four convolutional blocks. Each of the first three blocks contained a zero-padded convolutional layer followed by a batch normalization layer, a ReLU activation, and max-pooling. The fourth and final block contained a convolutional layer with 2D filters followed by max-pooling, where D denotes the dimensionality of the embedding space. The flattened output of the network is a vector of length 2D; the first D elements were treated as the mean of the Gaussian distribution and the remaining D elements as the diagonal covariance entries. The weights were initialized using He initialization and the biases were drawn from a uniform distribution.
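As a rough illustration, this architecture might be sketched in PyTorch as follows. The filter count (64), kernel size (3×3), input resolution (28×28), and 2-D embedding are assumptions, since the exact values are elided above; this is a minimal sketch, not the released implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class OmniglotEmbedNet(nn.Module):
    """Four-block convolutional network emitting the mean and diagonal
    covariance of a Gaussian embedding (hypothetical hyperparameters)."""
    def __init__(self, embed_dim=2, filters=64):
        super().__init__()
        def block(in_ch, out_ch, bn=True):
            layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)]
            if bn:
                layers += [nn.BatchNorm2d(out_ch), nn.ReLU()]
            layers.append(nn.MaxPool2d(2))
            return nn.Sequential(*layers)
        self.embed_dim = embed_dim
        self.net = nn.Sequential(
            block(1, filters),
            block(filters, filters),
            block(filters, filters),
            # final block: conv + max-pool only, 2*D output channels
            block(filters, 2 * embed_dim, bn=False),
        )

    def forward(self, x):
        out = self.net(x).flatten(1)               # length-2D vector per input
        mu = out[:, :self.embed_dim]               # first D entries: Gaussian mean
        var = F.softplus(out[:, self.embed_dim:])  # last D: non-negative diagonal covariance
        return mu, var

model = OmniglotEmbedNet(embed_dim=2)
mu, var = model(torch.randn(4, 1, 28, 28))
```

With 28×28 inputs, four rounds of 2×2 max-pooling collapse the spatial dimensions to 1×1, so flattening yields exactly the length-2D vector described above.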

All Omniglot models were trained with an initial learning rate that was cut in half on a fixed epoch schedule. Training was stopped early, using a patience parameter, when performance on the validation set no longer increased.

A.2 Synthetic Data

All images in the synthetic data set are the same fixed size. For orientation, we chose the class centers and a common standard deviation around each. For color, we manipulated the hue and kept value and saturation constant. Like orientation, hue is a circular quantity; treating hue as ranging over 0–360 degrees, we chose the color class centers and standard deviation in the same way as for orientation. Additionally, we added noise to a minority (15%) of the images used to train the model. For these, we added Gaussian noise to the hue of each pixel inside the shape, with the standard deviation of the hue noise drawn uniformly from a fixed range. We also added noise to the leg lengths of the L shapes: each leg length was chosen uniformly between 10% and 98% of its original length. See Figure 3 for examples.
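The circular hue sampling can be sketched as follows; the class count (four) and the standard deviation (10 degrees) are hypothetical, since the actual values are elided above.

```python
import numpy as np

def sample_hue(center_deg, sigma_deg, n, rng=None):
    """Draw hue values around a class center, wrapping on the hue circle."""
    rng = np.random.default_rng(rng)
    return (center_deg + sigma_deg * rng.standard_normal(n)) % 360.0

# e.g., four hypothetical hue classes evenly spaced on the circle
centers = np.linspace(0.0, 360.0, 4, endpoint=False)
hues = sample_hue(centers[3], 10.0, 1000, rng=0)
```

The modulo operation implements the circularity: samples drawn near 360 degrees wrap around to small hue values rather than falling outside the valid range.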

The network followed an architecture similar to the one used for Omniglot, except that we added two additional blocks of convolution, batch normalization, ReLU, and max-pooling because the images are larger. During training, we formed prototypes from a fixed number of instances per class and drew a fixed number of samples per query instance. We used a learning rate of 0.0001, and the models were stopped early, using a patience parameter, when performance on the validation set no longer increased.

A.3 N-Digit MNIST

For all N-digit MNIST experiments, we constructed an architecture that we believe to be identical to the one used for the HIB MNIST experiments, based on code provided by the authors (Oh et al., 2019). The network consisted of two convolutional blocks followed by two fully-connected layers. Each convolutional block contained a convolutional layer followed by a ReLU activation and max-pooling. The first convolutional layer used zero-padding and a fixed kernel size; the second was identical to the first except for its number of filters. The output of the second convolutional block was flattened and passed through a fully-connected layer with a ReLU activation, then a final fully-connected layer with 2D units, where D denotes the dimensionality of the embedding space. As in the Omniglot architectures, the first D entries of the output vector are treated as the mean and the remaining D elements as the diagonal covariance entries. The weights were initialized using Xavier uniform initialization and the biases were initialized to zero.

PN and SPE are trained episodically, with all performance results in the main article measured as the mean over 1000 random test episodes. All N-digit MNIST models were trained with an initial learning rate that was cut in half on a fixed epoch schedule. The models were stopped early, using a patience parameter, when performance on the validation set no longer increased. For 2-digit MNIST, each episode in training, validation, and seen-class testing contained all classes and a fixed number of support instances per class; for testing of unseen classes, each episode likewise contained all classes. For 3-digit MNIST, each episode contained a subset of the classes, with one number of support instances per class for training/validation and another for seen- and unseen-class testing.
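As a concrete illustration of episodic evaluation, here is a minimal NumPy sketch of a single Prototypical-Network-style episode, using deterministic embeddings (i.e., ignoring embedding variance); the toy class centers and instance counts are invented for the example.

```python
import numpy as np

def episode_accuracy(sup_x, sup_y, qry_x, qry_y, n_classes):
    """One episode: prototypes are per-class support means; queries are
    assigned to the nearest prototype by squared Euclidean distance."""
    protos = np.stack([sup_x[sup_y == c].mean(axis=0) for c in range(n_classes)])
    d2 = ((qry_x[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return (d2.argmin(axis=1) == qry_y).mean()

rng = np.random.default_rng(0)
# two well-separated toy classes in a 2-D embedding space
sup_x = np.concatenate([rng.normal(0, 0.1, (5, 2)), rng.normal(3, 0.1, (5, 2))])
sup_y = np.array([0] * 5 + [1] * 5)
qry_x = np.concatenate([rng.normal(0, 0.1, (10, 2)), rng.normal(3, 0.1, (10, 2))])
qry_y = np.array([0] * 10 + [1] * 10)
acc = episode_accuracy(sup_x, sup_y, qry_x, qry_y, 2)
```

Averaging this per-episode accuracy over many randomly constructed episodes yields the figures reported in the tables below.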

Appendix B Simulation Details

For all SPE models, the variance is governed by a trainable parameter, which we initialize according to a prescription that scales linearly with n, the number of support examples per episode during training, and with D, the dimensionality of the embedding. We chose this prescription for two reasons: (1) as the number of support examples increases, the variance of the prototype distribution approaches zero, so scaling linearly with n tends to provide a stronger training signal early in training; and (2) the amount of noise in the projection of an embedding should scale with the dimensionality of the embedding space so as to maintain unit volume. All models used this initialization.

The variance of each embedding dimension is guaranteed to be non-negative through a softplus transfer function.

Whether a model is trained with the naïve or the intersection sampler, we evaluate its performance using the naïve sampler with a fixed number of samples. This ensures that models are compared solely on the basis of the method by which they were trained.
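A minimal NumPy sketch of what such a naïve sampler could look like, under the assumption that class posteriors are softmaxes of negative squared distances to the prototypes; the prototype locations, query distribution, and sample count are invented for illustration.

```python
import numpy as np

def naive_sampler_posterior(mu, var, protos, n_samples=100, rng=None):
    """Monte Carlo class posterior: draw samples from the query's Gaussian
    embedding, softmax negative squared distances to the prototypes, and
    average over samples."""
    rng = np.random.default_rng(rng)
    z = mu + np.sqrt(var) * rng.standard_normal((n_samples, mu.size))
    logits = -((z[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    return p.mean(axis=0)  # marginalize over embedding uncertainty

protos = np.array([[0.0, 0.0], [4.0, 4.0]])
post = naive_sampler_posterior(np.array([0.5, 0.5]),
                               np.array([0.1, 0.1]), protos, rng=0)
```

Averaging the per-sample posteriors, rather than classifying the mean embedding, is what lets the prediction reflect the query's embedding uncertainty.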

Appendix C SPE Variants

We assumed diagonal covariance matrices throughout this work. Switching to a full covariance matrix would require matrix inversion, which is ordinarily too costly, but because one purpose of deep embeddings is visualization, there may be interesting cases involving 2D embeddings in which the cost of inversion is trivial. Moreover, using a diagonal covariance matrix causes class-discriminating features to align with the axes of the latent space, which, as we argued in the main article, is a virtue for interpretability.

Appendix D Tabular Results

D.1 Omniglot

Clean Support, Clean Query

       1-shot, 5-class   5-shot, 5-class   1-shot, 20-class   5-shot, 20-class
PN          75.7              82.6               45.0               55.9
SPE         76.9              82.3               49.7               55.3

Corrupt Support, Clean Query

       1-shot, 5-class   5-shot, 5-class   1-shot, 20-class   5-shot, 20-class
PN          50.0              65.9               23.6               31.7
SPE         50.7              73.9               25.6               41.6

Clean Support, Corrupt Query

       1-shot, 5-class   5-shot, 5-class   1-shot, 20-class   5-shot, 20-class
PN          48.9              52.3               21.7               25.6
SPE         47.8              52.3               22.8               26.8


Table 1: Test classification accuracy (%) on Omniglot with a D-dimensional embedding for clean-support/clean-query, corrupt-support/clean-query, and clean-support/corrupt-query conditions. PN is our implementation of Prototypical Networks (Snell et al., 2017); SPE is our model, trained with intersection sampling (1 sample per trial). Reported accuracy for each experimental configuration is the mean over random test episodes.

D.2 N-Digit MNIST

Clean Support, Clean Query

                  seen test classes           unseen test classes
                   N=2          N=3            N=2          N=3
                 D=2   D=3    D=2   D=3      D=2   D=3    D=2   D=3
Contrastive     88.2  95.0   65.8  87.3     85.5  84.8   59.0  85.5
HIB             87.9  95.2   65.0  87.3     87.3  91.0   64.4  88.2
PN              91.1  95.0   65.8  90.6     82.0  89.5   64.3  89.1
SPE             93.0  94.2   80.2  89.0     90.0  89.3   80.2  88.2

Corrupt Support, Clean Query

                  seen test classes           unseen test classes
                   N=2          N=3            N=2          N=3
                 D=2   D=3    D=2   D=3      D=2   D=3    D=2   D=3
Contrastive     76.2  92.2   49.5  77.6     76.5  73.3   42.6  73.2
HIB             81.6  94.3   54.0  81.2     80.8  86.7   53.9  81.2
PN              72.7  93.3   44.6  82.7     70.9  86.3   42.9  79.6
SPE             92.4  93.8   76.7  87.8     88.8  86.3   75.4  86.3

Clean Support, Corrupt Query

                  seen test classes           unseen test classes
                   N=2          N=3            N=2          N=3
                 D=2   D=3    D=2   D=3      D=2   D=3    D=2   D=3
Contrastive     43.5  51.6   29.3  44.7     46.3  44.8   26.2  42.0
HIB             49.9  57.8   31.8  49.9     53.5  57.0   32.1  50.2
PN              53.1  61.1   33.8  56.4     51.1  57.9   33.0  54.8
SPE             53.7  58.2   40.2  48.1     56.3  56.5   39.3  46.6


Table 2: Test classification accuracy (%) on 2- and 3-digit MNIST for clean-support/clean-query, corrupt-support/clean-query, and clean-support/corrupt-query. N: number of digits in each image; D: dimensionality of the embedding. Contrastive and HIB results are from Oh et al. (2019). PN is our implementation of Prototypical Networks (Snell et al., 2017); SPE is our model, trained with intersection sampling (1 sample per trial). Reported accuracy for PN and SPE in each experimental configuration is the mean over 1000 random test episodes.

Appendix E Corruption Procedure

Figure 8: Examples of occluded 2-digit sequences. Occlusion is based on random rectangles that black out portions of each digit.

The corruption algorithm was identical to the scheme used by Oh et al. (2019). A random rectangular occlusion of black pixels was generated by first sampling a patch width and a patch height from a uniform distribution, and then sampling the coordinates of the patch's top-left corner. The resulting occlusion has an area equal to the product of the sampled width and height. Note that if either sampled dimension was zero, the image was left unoccluded. Figure 8 shows examples of occluded 2-digit images.
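A hedged NumPy sketch of this corruption procedure; the maximum patch fraction and the exact sampling bounds are assumptions, since the numeric values are elided above.

```python
import numpy as np

def occlude(img, max_frac=0.6, rng=None):
    """Black out a random rectangle. Patch height/width are drawn uniformly
    up to max_frac of the image size; a zero-size patch leaves the image clean."""
    rng = np.random.default_rng(rng)
    h, w = img.shape
    ph = int(rng.integers(0, int(max_frac * h) + 1))  # patch height
    pw = int(rng.integers(0, int(max_frac * w) + 1))  # patch width
    if ph == 0 or pw == 0:
        return img.copy()          # degenerate patch: image left unoccluded
    y0 = int(rng.integers(0, h - ph + 1))  # top-left corner coordinates
    x0 = int(rng.integers(0, w - pw + 1))
    out = img.copy()
    out[y0:y0 + ph, x0:x0 + pw] = 0.0
    return out

img = np.ones((28, 28))
corrupted = occlude(img, rng=0)
```

For multi-digit images, the same routine would be applied to each digit's sub-image independently, matching the per-digit corruption described below.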

For Omniglot, we trained and validated on corrupted imagery only when the test condition contained a corrupted support or query set. When testing on clean support and clean query sets, the training and validation sets were left unoccluded. When testing on corrupted imagery, each character in the training and validation sets was corrupted independently with a fixed probability.

For N-digit MNIST, each digit of each training and validation image was corrupted independently with a fixed probability, regardless of the test imagery. This matched Oh et al. (2019).

During testing on both data sets, we considered both clean and corrupt support sets, as well as clean and corrupt query sets. In a clean set, all digits/characters were unoccluded; in a corrupt set, each digit/character in each image was occluded according to the procedure described above.
