Sampling Matters in Deep Embedding Learning
Deep embeddings answer one simple question: How similar are two images? Learning these embeddings is the bedrock of verification, zero-shot learning, and visual search. The most prominent approaches optimize a deep convolutional network with a suitable loss function, such as contrastive loss or triplet loss. While a rich line of work focuses solely on the loss functions, we show in this paper that selecting training examples plays an equally important role. We propose distance weighted sampling, which selects more informative and stable examples than traditional approaches. In addition, we show that a simple margin based loss is sufficient to outperform all other loss functions. We evaluate our approach on the Stanford Online Products, CARS196, and the CUB200-2011 datasets for image retrieval and clustering, and on the LFW dataset for face verification. Our method achieves state-of-the-art performance on all of them.
Models that transform images into rich, semantic representations lie at the heart of modern computer vision, with applications ranging from zero-shot learning [bucher2016improving, yuan2016hard] and visual search [hadi2015buy, bell2015learning, song2016learnable, oh2016deep] to face recognition [facenet, chopra2005learning, parkhi2015deep, sohn2016improved] and fine-grained retrieval [song2016learnable, oh2016deep, sohn2016improved]. Deep networks trained to respect pairwise relationships have emerged as the most successful embedding models [bromley1993signature, chopra2005learning, facenet, ustinova2016learning].
The core idea of deep embedding learning is simple: pull similar images closer in embedding space and push dissimilar images apart. For example, the contrastive loss [hadsell2006dimensionality] forces all positive images to be close, while all negatives should be separated by a certain fixed distance. However, using the same fixed distance for all images can be quite restrictive, discouraging any distortions in the embedding space. This motivated the triplet loss, which only requires negative images to be farther away than any positive images on a per-example basis [facenet]. This triplet loss is currently among the best-performing losses on standard embedding tasks [zhuang2016fast, facenet, oh2016deep]. Unlike pairwise losses, the triplet loss does not just change the loss function in isolation; it also changes the way positive and negative examples are selected. This provides us with two knobs to turn: the loss and the sampling strategy. See \figref{teaser} for an illustration.
In this paper, we show that sample selection in embedding learning plays an equal or more important role than the loss. For example, different sampling strategies lead to drastically different solutions for the same loss function. At the same time, many different loss functions perform similarly under a good sampling strategy: a contrastive loss works almost as well as the triplet loss if the two use the same sampling strategy. We analyze existing sampling strategies and show why they work and when they fail. We then propose a new sampling strategy, where samples are drawn uniformly according to their relative distance from one another. This corrects the bias induced by the geometry of the embedding space, while at the same time ensuring that every data point has a chance of being sampled. Our proposed sampling leads to a lower variance of gradients, and thus stabilizes training, resulting in a qualitatively better embedding irrespective of the loss function.
Loss functions obviously also matter. We propose a simple margin-based loss as an extension to the contrastive loss. It only encourages all positive samples to be within a certain distance of each other, rather than being as close as possible. This relaxes the loss, making it more robust. In addition, by using isotonic regression, our margin based loss focuses on the relative orders instead of absolute distances.
Our margin based loss and distance weighted sampling achieve state-of-the-art image retrieval and clustering performance on the Stanford Online Products, CARS196, and the CUB200-2011 datasets. The combination also outperforms previous state-of-the-art results on the LFW face verification dataset [LFWTech] using standard publicly available training data. Both our loss function and sampling strategy are easy to implement and efficient to train.
2 Related Work
The idea of using neural networks to extract features that respect certain relationships dates back to the 1990s. Siamese Networks [bromley1993signature] find an embedding space such that similar examples have similar embeddings and dissimilar examples have dissimilar embeddings. Such networks are trained end-to-end, sharing weights between all mappings. Siamese Networks were first applied to signature verification, and later extended to face verification and dimensionality reduction [chopra2005learning, hadsell2006dimensionality]. However, given the limited compute power at the time and their non-convex nature, these approaches initially did not enjoy much attention. Convex approaches were much more popular [xing2002distance, comon1994independent]. For example, the triplet loss [schultz2003learning, weinberger2009distance] is one of the most prominent methods that emerged from convex optimization.
Given sufficient data and computational power, both schools of thought were combined into a Siamese architecture using triplet losses. This led to near-human performance in face verification [facenet, parkhi2015deep]. Motivated by the triplet loss, some methods enforce constraints on even more examples. For example, PDDM [huang2016local] and Histogram Loss [ustinova2016learning] use quadruplets. Beyond that, the n-pair loss [sohn2016improved] and Lifted Structure [oh2016deep] define constraints on all images in a batch.
This plethora of loss functions is quite reminiscent of the ranking problem in information retrieval. There, a combination of individual, pair-wise [HerGraObe99b], and list-wise approaches [CoFiRank] is used to maximize relevance. Of note is isotonic regression, which disentangles the pairwise comparisons for greater computational efficiency. See [MoonZhengSmolaEtAl] for an overview.
Some papers explore modeling of other properties. Structural Clustering [song2016learnable] optimizes for clustering quality. PDDM [huang2016local] proposes a new module to model local feature structure. HDC [yuan2016hard] trains an ensemble to model examples of different “hard levels”. In contrast, here we show that a simple pairwise loss is sufficient if paired with the right sampling strategy.
Example selection techniques are relatively less studied. For the contrastive loss it is common to select from all possible pairs at random [hadsell2006dimensionality, chopra2005learning, bell2015learning], and sometimes with hard negative mining [simo2015discriminative]. For the triplet loss, semi-hard negative mining, first used in FaceNet [facenet], is widely adopted [oh2016deep, parkhi2015deep]. Sampling has been studied for stochastic optimization [zhang2015stochastic] with the goal of accelerating convergence to the same global loss function. In contrast, in embedding learning the sampling actually changes the overall loss function considered. In this paper we show how sampling affects the real-world performance of deep embedding learning.
Let $f(x_i)$ be an embedding of a datapoint $x_i \in \mathbb{R}^N$, where $f: \mathbb{R}^N \rightarrow \mathbb{R}^D$ is a differentiable deep network with parameters $\Theta$. Often $f(x_i)$ is normalized to have unit length for training stability [facenet]. Our goal is to learn an embedding that keeps similar data points close, while pushing dissimilar datapoints apart. Formally, we define the distance between two datapoints as $D_{ij} := \|f(x_i) - f(x_j)\|$, where $\|\cdot\|$ denotes the Euclidean norm. For any positive pair of datapoints ($y_{ij} = 1$) this distance should be small, and for any negative pair ($y_{ij} = 0$) it should be large.
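The contrastive loss directly optimizes this distance by encouraging all positive distances to approach $0$, while keeping negative distances above a certain threshold:
$$\ell^{\text{contrast}}(i, j) := y_{ij} D_{ij}^2 + \left(1 - y_{ij}\right)\left[\alpha - D_{ij}\right]_+^2,$$
where $[x]_+ := \max(x, 0)$.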
One drawback of the contrastive loss is that we have to select a constant margin for all pairs of negative samples. This implies that visually diverse classes are embedded in the same small space as visually similar ones. The embedding space does not allow for distortions.
In contrast, the triplet loss merely tries to keep all positives closer than any negatives for each example:
$$\ell^{\text{triplet}}(a, p, n) := \left[D_{ap}^2 - D_{an}^2 + \alpha\right]_+$$
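This formulation allows the embedding space to be arbitrarily distorted and does not impose a constant margin on the negative distances themselves: $\alpha$ only constrains the gap between $D_{an}^2$ and $D_{ap}^2$.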
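To make the two definitions concrete, here is a minimal NumPy sketch of both losses. It assumes embeddings are stored row-wise in an array and uses the $y_{ij} \in \{0, 1\}$ convention of the contrastive loss; the function names and toy values are illustrative and not taken from any released code.

```python
import numpy as np


def pairwise_distances(emb):
    """D_ij = ||f(x_i) - f(x_j)|| for all rows of an (N, D) embedding matrix."""
    sq = np.sum(emb ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * emb @ emb.T
    return np.sqrt(np.maximum(d2, 0.0))  # clamp tiny negative values caused by round-off


def contrastive_loss(d, y, alpha):
    """y = 1 for positive pairs and 0 for negative pairs; alpha is the fixed margin."""
    return y * d ** 2 + (1 - y) * np.maximum(alpha - d, 0.0) ** 2


def triplet_loss(d_ap, d_an, alpha):
    """[D_ap^2 - D_an^2 + alpha]_+ for anchor-positive and anchor-negative distances."""
    return np.maximum(d_ap ** 2 - d_an ** 2 + alpha, 0.0)


# Unit-length embeddings, as is common for training stability.
emb = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])
D = pairwise_distances(emb)
print(contrastive_loss(D[0, 1], 1, alpha=1.0))   # positive pair: pulled towards 0
print(contrastive_loss(D[0, 2], 0, alpha=1.0))   # negative pair beyond the margin: no loss
print(triplet_loss(D[0, 1], D[0, 2], alpha=0.2))  # ordering already satisfied: no loss
```

Note that the contrastive loss pays for every positive pair with $D_{ij} > 0$, while the triplet loss is zero as soon as the ordering holds with margin $\alpha$.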
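From the risk minimization perspective, one might aim at optimizing the aggregate loss over all pairs or triples, respectively. That is,
$$R^{\text{contrast}} := \sum_{(i,j)} \ell^{\text{contrast}}(i, j) \quad \text{or} \quad R^{\text{triplet}} := \sum_{(a,p,n)} \ell^{\text{triplet}}(a, p, n).$$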
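This is computationally infeasible: the number of pairs grows quadratically, and the number of triples cubically, with the size of the training set. Moreover, once the network converges, most samples contribute in a minor way, as very few of the negative margins are violated.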
This led to the emergence of many heuristics to accelerate convergence. For the contrastive loss, hard negative mining usually offers faster convergence. For the triplet loss, the picture is less clear, as hard negative mining often leads to collapsed models, i.e. all images have the same embedding. FaceNet [facenet] thus proposed a somewhat mysterious semi-hard negative mining: given an anchor $a$ and a positive example $p$, obtain a negative instance $n$ via
$$n^{\star}_{ap} := \operatorname*{argmin}_{n:\, D_{an} > D_{ap}} D_{an}$$
within a batch. This yields a violating example that is fairly hard but not too hard. Batch construction also matters. In order to obtain more informative triplets, FaceNet uses a batch size of around 1,800 and ensures that each identity has roughly 40 images in a batch [facenet]. Even how best to select triplets within a batch is unclear. Parkhi \etal [parkhi2015deep] use online selection, so that only one triplet is sampled for every pair. OpenFace [amos2016openface] employs offline triplet selection, so that one third of the images in a batch serve as anchors, positives, and negatives, respectively.
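The semi-hard rule $n^{\star}_{ap} = \operatorname{argmin}_{n:\, D_{an} > D_{ap}} D_{an}$ can be sketched as a search over one batch. This assumes a precomputed $(B, B)$ distance matrix and integer class labels; `semi_hard_negative` is a hypothetical helper that implements only this argmin (FaceNet's own implementation details may differ), and it returns `None` when the batch contains no negative farther away than the positive.

```python
import numpy as np


def semi_hard_negative(d, labels, a, p):
    """Closest negative of anchor a that is still farther away than positive p.

    d: (B, B) pairwise distance matrix; labels: (B,) integer class ids.
    """
    is_negative = labels != labels[a]
    candidates = np.where(is_negative & (d[a] > d[a, p]))[0]
    if candidates.size == 0:
        return None  # no semi-hard negative in this batch
    return int(candidates[np.argmin(d[a, candidates])])


x = np.array([0.0, 0.3, 0.2, 0.5, 0.9])   # 1-D "embeddings" for readability
labels = np.array([0, 0, 1, 1, 1])
d = np.abs(x[:, None] - x[None, :])
print(semi_hard_negative(d, labels, a=0, p=1))  # 3: item 2 (distance 0.2) is too hard
```

Item 2 is the hardest negative, but it is closer than the positive, so the semi-hard rule skips it.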
In short, sampling matters. It implicitly defines a rather heuristic objective function by weighting samples. Such an approach makes it hard to reproduce and extend the insights to different datasets, different optimization frameworks or different architectures. In the next section, we analyze some of these techniques, and explain why they offer better results. We then propose a new sampling strategy that outperforms current state of the art.
4 Distance Weighted Margin-Based Loss
To understand what happens when sampling negatives uniformly, recall that our embeddings are typically constrained to the unit sphere $\mathbb{S}^{n-1} \subset \mathbb{R}^n$ for large $n$. Consider the situation where the points are uniformly distributed on the sphere. In this case, the distribution of pairwise distances follows
$$q(d) \propto d^{\,n-2}\left[1 - \tfrac{1}{4}d^2\right]^{\frac{n-3}{2}}.$$
See [sphere] for a derivation. \figref{distortion} shows concentration of measure occurring. In fact, in high-dimensional space, $q(d)$ approaches $\mathcal{N}\left(\sqrt{2}, \frac{1}{2n}\right)$. In other words, if negative examples are scattered uniformly and we sample them randomly, we are likely to obtain examples that are $\sqrt{2}$ away. For thresholds less than $\sqrt{2}$, this induces no loss, and thus no progress for learning. Learned embeddings follow a very similar distribution, and thus the same reasoning applies.
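A quick simulation illustrates the concentration effect. The sketch below draws uniform points on $\mathbb{S}^{127}$ (normalized Gaussian vectors), measures distances between independent pairs, and compares them with the mode of $q(d)$ and the limiting $\mathcal{N}(\sqrt{2}, \frac{1}{2n})$; the sample size and seed are arbitrary choices.

```python
import numpy as np


def log_q(d, n):
    """Unnormalized log-density of pairwise distances for uniform points on S^{n-1}."""
    return (n - 2.0) * np.log(d) + (n - 3.0) / 2.0 * np.log(1.0 - 0.25 * d ** 2)


n = 128
rng = np.random.default_rng(0)
x = rng.standard_normal((2000, n))
x /= np.linalg.norm(x, axis=1, keepdims=True)     # uniform points on the unit sphere
d = np.linalg.norm(x[:1000] - x[1000:], axis=1)   # 1000 independent pairwise distances

grid = np.linspace(0.01, 1.99, 1000)
mode = grid[np.argmax(log_q(grid, n))]            # peak of q(d)
print(f"empirical mean {d.mean():.3f}, std {d.std():.3f}")
print(f"mode of q {mode:.3f}, sqrt(2) = {np.sqrt(2):.3f}, 1/sqrt(2n) = {1 / np.sqrt(2 * n):.3f}")
```

Almost all sampled distances fall within a few hundredths of $\sqrt{2}$, which is why uniformly drawn negatives rarely violate a margin below $\sqrt{2}$.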
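Sampling negative examples that are too hard causes a different issue. Consider a negative pair $t := (a, n)$ or a triplet $t := (a, p, n)$. The gradient with respect to the negative example $f(x_n)$ is of the form
$$\partial_{f(x_n)} \ell^{(\cdot)} = \frac{h_{an}}{\|h_{an}\|}\, w(t)$$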
for some function $w(\cdot)$ and $h_{an} := f(x_a) - f(x_n)$. Note that the first term, $h_{an}/\|h_{an}\|$, determines the direction of the gradient. A problem arises when $\|h_{an}\|$ is small and our estimates of the embedding are noisy. Given enough noise $z$ introduced by the training algorithm, the direction $\frac{h_{an} + z}{\|h_{an} + z\|}$ is dominated by noise. \figref{variance} shows the nuclear norm of the covariance matrix for the direction of the gradient under Gaussian noise $z$. We can see that when negative examples are too close (too hard), the gradient has high variance and a low signal-to-noise ratio. At the same time, random samples are often too far apart to yield a good signal.
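The variance argument can be reproduced in a few lines. This sketch assumes isotropic Gaussian noise with a hypothetical standard deviation `sigma = 0.1` and a 128-dimensional embedding, and measures the nuclear norm of the covariance of the noisy unit direction $(h_{an} + z)/\|h_{an} + z\|$ for short and long $h_{an}$.

```python
import numpy as np


def direction_nuclear_norm(dist, dim=128, sigma=0.1, samples=2000, seed=0):
    """Nuclear norm of Cov[(h + z) / ||h + z||] with ||h|| = dist and z ~ N(0, sigma^2 I)."""
    rng = np.random.default_rng(seed)
    h = np.zeros(dim)
    h[0] = dist                                      # stands in for h_an = f(x_a) - f(x_n)
    z = sigma * rng.standard_normal((samples, dim))  # hypothetical training noise
    v = h + z
    v /= np.linalg.norm(v, axis=1, keepdims=True)    # noisy gradient directions
    cov = np.cov(v, rowvar=False)
    return np.linalg.norm(cov, ord="nuc")            # sum of singular values


for dist in (0.05, 0.5, 1.4):
    print(dist, round(direction_nuclear_norm(dist), 3))
```

For a very close negative the direction is essentially random (the nuclear norm is close to its maximum of about 1), whereas a distant negative keeps a stable direction.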
Distance weighted sampling.
We thus propose a new sampling distribution that corrects the bias while controlling the variance. Specifically, we sample uniformly according to distance, i.e. sampling with weights $q(d)^{-1}$. This gives us examples which are spread out instead of being clustered around a small region. To avoid noisy samples, we clip the weighted sampling. Formally, given an anchor example $a$, distance weighted sampling samples the negative pair $(a, n^{\star})$ with
$$\Pr\left(n^{\star} = n \mid a\right) \propto \min\left(\lambda,\, q^{-1}\left(D_{an}\right)\right),$$
where $\lambda$ is the clipping threshold.
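Below is a minimal sketch of distance weighted sampling, $\Pr(n^{\star} = n \mid a) \propto \min(\lambda, q^{-1}(D_{an}))$, for a single anchor, assuming unit-length embeddings in $n$ dimensions. Because $q^{-1}(d)$ spans many orders of magnitude for $n = 128$, the sketch clips in log space; `log_lambda` is a hypothetical parameter equal to $\log \lambda$ measured against the unnormalized $q$, and clamping $d$ away from $0$ and $2$ is a numerical safeguard of this sketch. At least one negative is assumed to be present.

```python
import numpy as np


def log_q(d, n):
    """Unnormalized log-density of pairwise distances for uniform points on S^{n-1}."""
    return (n - 2.0) * np.log(d) + (n - 3.0) / 2.0 * np.log(1.0 - 0.25 * d ** 2)


def distance_weighted_probs(d_row, labels, a, n, log_lambda):
    """Pr(n* = j | a) proportional to min(lambda, 1 / q(D_aj)) over negatives j of anchor a.

    d_row: distances from anchor a to every item in the batch; labels: class ids.
    """
    d = np.clip(d_row, 1e-4, 2.0 - 1e-4)          # keep both logs finite (safeguard)
    log_w = np.minimum(log_lambda, -log_q(d, n))  # clipped inverse density, in log space
    log_w[labels == labels[a]] = -np.inf          # exclude positives and the anchor itself
    w = np.exp(log_w - log_w.max())               # subtract the max for numerical stability
    return w / w.sum()


def sample_negative(d_row, labels, a, n, log_lambda, rng):
    p = distance_weighted_probs(d_row, labels, a, n, log_lambda)
    return int(rng.choice(len(p), p=p))


d_row = np.array([0.0, 0.3, 0.8, 1.4, 1.6])   # anchor 0 to every batch item
labels = np.array([0, 0, 1, 1, 1])
p = distance_weighted_probs(d_row, labels, a=0, n=128, log_lambda=20.0)
print(p)
print(sample_negative(d_row, labels, 0, 128, 20.0, np.random.default_rng(0)))
```

In this toy row the negative at distance 0.8 reaches the cap $\lambda$ and dominates, and the one at 1.6 outweighs the one at 1.4, because uniformly scattered points are rarer far from $\sqrt{2}$.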
\figref{sampling_methods} compares the simulated examples drawn from different strategies along with their variance of gradients. Hard negative mining always offers examples in the high-variance region. This leads to noisy gradients that cannot effectively push two examples apart, and consequently to a collapsed model. Random sampling yields only easy examples that induce no loss. Semi-hard negative mining finds a narrow set in between. While it might converge quickly at the beginning, at some point no examples are left within the band, and the network stops making progress. FaceNet reports a consistent finding: the decrease of the loss slows down drastically after some point, and their final system took 80 days to train [facenet]. Distance weighted sampling offers a wide range of examples, and thus steadily produces informative examples while controlling the variance. In \secref{exp}, we will see that distance weighted sampling brings performance improvements for almost all loss functions tested. Of course, sampling only solves half of the problem, but it puts us in a position to analyze various loss functions.
\figref{contrast} and \figref{triplet_l22} depict the contrastive loss and the triplet loss. There are two key differences, which in general explain why the triplet loss outperforms the contrastive loss. First, the triplet loss does not assume a predefined threshold to separate similar and dissimilar images. Instead, it enjoys the flexibility to distort the space to tolerate outliers, and to adapt to different levels of intra-class variance for different classes. Second, the triplet loss only requires positive examples to be closer than negative examples, while the contrastive loss spends effort on gathering all positive examples as close together as possible. The latter is not necessary. After all, maintaining correct relative relationships is sufficient for most applications, including image retrieval, clustering, and verification.
On the other hand, in \figref{triplet_l22} we also observe the concave shape of the loss function for negative examples in the triplet loss. In particular, note that for hard negatives (with small $D_{an}$), the gradient with respect to the negative example approaches zero, since the gradient of $-D_{an}^2$ with respect to $f(x_n)$ has norm $2 D_{an}$. It is not hard to see why hard negative mining results in a collapsed model in this case: it gives large attracting gradients from hard positive pairs but small repelling gradients from hard negative pairs, so all points are eventually gathered into the same point. To make the loss stable for examples at all distances, one simple remedy is to use $\ell_2$ instead of $\ell_2^2$, i.e.
$$\ell^{\text{triplet},\ell_2}(a, p, n) := \left[D_{ap} - D_{an} + \alpha\right]_+$$
\figref{triplet_l2} presents the loss function. Now the gradient of each distance term with respect to any embedding $f(x)$ always has length one. See e.g. [hazan2015beyond, levy2016power] for more discussion of the benefits of using gradients of a fixed length. This simple fix together with distance weighted sampling already outperforms the traditional triplet loss, as shown in \secref{exp}.
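The difference in gradient length can be checked directly. The sketch below uses closed-form (sub)gradients with respect to $f(x_n)$ for both triplet variants, $2 h_{an}$ for the squared loss and $h_{an}/D_{an}$ for the $\ell_2$ loss whenever the hinge is active; the toy vectors and $\alpha = 0.2$ are arbitrary illustrations.

```python
import numpy as np


def grad_wrt_negative(f_a, f_p, f_n, alpha, squared):
    """Gradient of a triplet loss with respect to the negative embedding f(x_n).

    squared=True : [D_ap^2 - D_an^2 + alpha]_+  -> 2 * h_an when active (norm 2 * D_an)
    squared=False: [D_ap   - D_an   + alpha]_+  -> h_an / D_an when active (norm 1)
    with h_an = f(x_a) - f(x_n); the gradient is zero when the hinge is inactive.
    """
    h_an = f_a - f_n
    d_ap = np.linalg.norm(f_a - f_p)
    d_an = np.linalg.norm(h_an)
    if squared:
        active = d_ap ** 2 - d_an ** 2 + alpha > 0
        g = 2.0 * h_an
    else:
        active = d_ap - d_an + alpha > 0
        g = h_an / d_an
    return g if active else np.zeros_like(g)


f_a = np.array([1.0, 0.0])
f_p = np.array([0.8, 0.6])
f_n = np.array([np.cos(0.05), np.sin(0.05)])   # hard negative, D_an is about 0.05
g_sq = grad_wrt_negative(f_a, f_p, f_n, alpha=0.2, squared=True)
g_l2 = grad_wrt_negative(f_a, f_p, f_n, alpha=0.2, squared=False)
print(np.linalg.norm(g_sq), np.linalg.norm(g_l2))  # roughly 0.1 versus 1.0
```

A gradient descent step moves $f(x_n)$ along $-g$, i.e. away from the anchor; with the squared loss that push almost vanishes for this hard negative.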
Margin based loss.
These observations motivate our design of a loss function which enjoys the flexibility of the triplet loss, has a shape suitable for examples at all distances, and offers the computational efficiency of a contrastive loss. The basic idea can be traced back to the insight that in ordinal regression only the relative order of scores matters [Joachims02]. That is, we only need to know the crossover between both sets. Isotonic regression exploits this by estimating such a threshold separately and then penalizing scores relative to the threshold. We use the same trick, now applied to pairwise distances rather than score functions. The adaptive margin based loss is defined as
$$\ell^{\text{margin}}(i, j) := \left[\alpha + y_{ij}\left(D_{ij} - \beta\right)\right]_+$$
Here $\beta$ is a variable that determines the boundary between positive and negative pairs, $\alpha$ controls the margin of separation, and $y_{ij} \in \{-1, 1\}$ (note that, unlike in the contrastive loss, negative pairs are labeled $-1$ here). \figref{margin_loss} visualizes this new loss function. We can see that it relaxes the constraint on positive examples from the contrastive loss. It effectively imposes a large-margin loss on the shifted distance $D_{ij} - \beta$. This loss is very similar to a support vector classifier (SVC) [svm].
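As a minimal sketch, the margin based loss $[\alpha + y_{ij}(D_{ij} - \beta)]_+$ is a single hinge on the shifted distance. The example assumes the $y_{ij} \in \{-1, 1\}$ convention and uses arbitrary toy values for $\alpha$, $\beta$, and the distances.

```python
import numpy as np


def margin_loss(d, y, alpha, beta):
    """[alpha + y_ij (D_ij - beta)]_+ with y_ij = +1 (positive) or -1 (negative).

    Positive pairs are penalized only when D_ij > beta - alpha,
    negative pairs only when D_ij < beta + alpha.
    """
    return np.maximum(alpha + y * (d - beta), 0.0)


d = np.array([0.9, 1.3, 1.0, 1.5])
y = np.array([1, 1, -1, -1])
print(margin_loss(d, y, alpha=0.2, beta=1.2))  # approximately [0, 0.3, 0.4, 0]
```

The positive pair at distance 0.9 incurs no loss even though it is not close to zero, which is exactly the relaxation relative to the contrastive loss.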
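To enjoy the same flexibility as the triplet loss, we need a more flexible boundary parameter $\beta(i)$, which depends on class-specific and example-specific terms:
$$\beta(i) := \beta^{(0)} + \beta^{(\text{class})}_{c(i)} + \beta^{(\text{img})}_{i},$$
where $c(i)$ denotes the class of example $i$.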
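In particular, the example-specific offset $\beta^{(\text{img})}_i$ plays the same role as the threshold in a triplet loss. It is infeasible to manually select all the $\beta^{(\text{class})}$s and $\beta^{(\text{img})}$s. Instead, we would like to learn these parameters jointly. Fortunately, the gradient with respect to $\beta$ can be easily calculated as
$$\partial_{\beta}\, \ell^{\text{margin}}(i, j) = -y_{ij}\, \mathbb{1}\left\{\alpha + y_{ij}\left(D_{ij} - \beta\right) > 0\right\}$$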
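It is clear that larger values of $\beta$ are more desirable, since they amount to a better use of the embedding space. Hence, to regularize $\beta$, we incorporate a hyperparameter $\nu$, which leads to the optimization problem
$$\underset{f,\,\beta}{\text{minimize}} \; \sum_{(i,j)} \left( \ell^{\text{margin}}(i, j) + \nu\left(\beta^{(0)} + \beta^{(\text{class})}_{c(i)} + \beta^{(\text{img})}_{i}\right) \right)$$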
Here $\nu$ adjusts the difference between the number of points that violate the margin on the left and on the right. This can be seen by observing that their gradients need to cancel out at an optimal $\beta$. Note that the use of $\nu$ here is very similar to the $\nu$-trick in $\nu$-SVM [scholkopf2000new].
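Learning $\beta$ only needs the subgradient $-y_{ij}\,\mathbb{1}\{\alpha + y_{ij}(D_{ij} - \beta) > 0\}$ of the margin loss. The sketch below performs one plain subgradient step on $\beta^{(0)}$ and $\beta^{(\text{class})}$; it omits $\beta^{(\text{img})}$, assumes the $\nu$ regularizer contributes $+\nu$ to the derivative of each pair's term, and uses a hypothetical learning rate `lr`. Since these parameters are learned jointly with the network, a real implementation would let the same optimizer update them.

```python
import numpy as np


def beta_subgradient_step(d, y, cls, beta0, beta_cls, alpha, nu, lr):
    """One subgradient step on beta^(0) and beta^(class) for the margin loss.

    d, y : distances D_ij and labels y_ij in {-1, +1} for a batch of pairs
    cls  : class index c(i) of the anchor i of each pair (non-negative ints)
    Assumes beta(i) = beta0 + beta_cls[c(i)] (example-specific offsets omitted)
    and a regularizer that adds nu * beta(i) per pair (the sign is an assumption).
    """
    beta = beta0 + beta_cls[cls]
    active = alpha + y * (d - beta) > 0            # pairs that violate the margin
    g = -y * active + nu                           # d/d beta(i) of loss + nu * beta(i)
    new_beta0 = beta0 - lr * g.sum()               # beta0 is shared by every pair
    new_beta_cls = beta_cls - lr * np.bincount(cls, weights=g, minlength=len(beta_cls))
    return new_beta0, new_beta_cls


d = np.array([1.5, 1.5, 1.0, 1.6])
y = np.array([1, 1, -1, -1])
cls = np.array([0, 0, 1, 1])
b0, bc = beta_subgradient_step(d, y, cls, beta0=0.0, beta_cls=np.array([1.0, 1.0]),
                               alpha=0.2, nu=0.0, lr=0.1)
print(b0, bc)  # approximately 0.1 and [1.2, 0.9]
```

Class 0 has positives that are too far apart, so its boundary moves outward; class 1 has a negative that is too close, so its boundary moves inward.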
Relationship to isotonic regression.
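Optimizing the margin based loss can be viewed as solving a ranking problem for distances. Technically, it shares similarities with learning-to-rank problems in information retrieval [zheng2008query, MoonZhengSmolaEtAl]. To see this, first note that at the optimal $\beta$, the empirical risk can be written as
$$R^{\text{margin}}(\cdot, \beta^{\star}) := \min_{\beta} \sum_{(i,j)} \left[\alpha + y_{ij}\left(D_{ij} - \beta\right)\right]_+$$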
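One can show that $R^{\text{margin}}(\cdot, \beta^{\star}) = \sum_{(i,j)} \left|D^{\alpha}_{ij} - \hat{D}^{\alpha}_{ij}\right|$, where the $\hat{D}^{\alpha}_{ij}$ are the solution to
$$\underset{\hat{D}^{\alpha}}{\text{minimize}} \sum_{(i,j)} \left|D^{\alpha}_{ij} - \hat{D}^{\alpha}_{ij}\right| \quad \text{subject to} \quad \hat{D}^{\alpha}_{ij} \le \hat{D}^{\alpha}_{kl} \;\text{ for all } (i,j) \in \mathcal{P},\ (k,l) \in \mathcal{N},$$
where $D^{\alpha}_{ij} := D_{ij} + \alpha\, y_{ij}$, and $\mathcal{P}$ and $\mathcal{N}$ denote the sets of positive ($y_{ij} = 1$) and negative ($y_{ij} = -1$) pairs. This is an isotonic regression defined on absolute error. We see that the margin based loss is the amount of “minimum-effort” updates needed to maintain relative orders. It focuses on the relative relationships, i.e. on separating the positive-pair distances from the negative-pair distances. This is in contrast to traditional loss functions such as the contrastive loss, where losses are defined relative to a predefined threshold.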
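The equivalence can be checked numerically on a toy example. The sketch relies on the fact that, for a scalar $\beta$, the risk $\sum_{(i,j)} [\alpha + y_{ij}(D_{ij} - \beta)]_+$ is convex and piecewise linear with kinks at the shifted distances $D_{ij} + \alpha y_{ij}$, so one of them is a minimizer; the isotonic solution is then obtained by clipping positives from above and negatives from below at that $\beta$. The distances and $\alpha$ are made-up values, and at least one positive and one negative pair are assumed.

```python
import numpy as np


def margin_risk_at_best_beta(d, y, alpha):
    """min over a scalar beta of sum_(i,j) [alpha + y_ij (D_ij - beta)]_+.

    The objective is convex and piecewise linear in beta with kinks at the shifted
    distances D_ij + alpha * y_ij, so one of them attains the minimum.
    """
    shifted = d + alpha * y
    risks = [np.maximum(alpha + y * (d - b), 0.0).sum() for b in shifted]
    k = int(np.argmin(risks))
    return risks[k], shifted[k]


def isotonic_fit(shifted, y, t):
    """Feasible fit with every positive <= t <= every negative: clip positives
    from above and negatives from below at the threshold t."""
    return np.where(y == 1, np.minimum(shifted, t), np.maximum(shifted, t))


d = np.array([0.7, 1.1, 1.3, 0.9, 1.0, 1.4, 1.6, 1.2])   # toy distances
y = np.array([1, 1, 1, 1, -1, -1, -1, -1])
alpha = 0.1
risk, beta_star = margin_risk_at_best_beta(d, y, alpha)
shifted = d + alpha * y
fit = isotonic_fit(shifted, y, beta_star)
print(risk, np.abs(shifted - fit).sum())   # both approximately 0.6
```

Only the pairs on the wrong side of the learned boundary contribute, each by how far it would have to move to restore the order.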
We evaluate our method on image retrieval, clustering and verification. For image retrieval and clustering, we use the Stanford Online Products [oh2016deep], CARS196 [krause20133d], and CUB200-2011 [WelinderEtal2010] datasets, following the experimental setup of Song \etal [oh2016deep]. The Stanford Online Products dataset contains 120,053 images of 22,634 categories. The first 11,318 categories are used for training, and the remaining are used for testing. The CARS196 dataset contains 16,185 car images of 196 models. We use the first 98 models for training, and the remaining for testing. The CUB200-2011 dataset contains 11,788 bird images of 200 species. The first 100 species are used for training, the remainder for testing.
We evaluate the quality of image retrieval based on the standard Recall@K metric, following Song \etal [oh2016deep]. We use the normalized mutual information (NMI) score to evaluate the quality of a clustering $\mathbb{C}$, given a ground-truth clustering $\Omega$; NMI normalizes the mutual information $I(\Omega; \mathbb{C})$ by the entropies $H(\Omega)$ and $H(\mathbb{C})$. We use the K-means algorithm for clustering.
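For reference, here is a minimal Recall@K sketch under the usual definition: a query counts as a hit if any of its $K$ nearest neighbours, excluding itself, shares its class. It assumes squared Euclidean distances on the embeddings and NumPy arrays for the labels; ties are broken arbitrarily by `argsort`.

```python
import numpy as np


def recall_at_k(emb, labels, k):
    """Fraction of queries with at least one same-class item among their k nearest neighbours."""
    sq = np.sum(emb ** 2, axis=1)
    d = sq[:, None] + sq[None, :] - 2.0 * emb @ emb.T   # squared Euclidean distances
    np.fill_diagonal(d, np.inf)                          # a query never retrieves itself
    nn = np.argsort(d, axis=1)[:, :k]                    # k nearest neighbours per query
    hits = (labels[nn] == labels[:, None]).any(axis=1)
    return float(hits.mean())


emb = np.array([[0.0], [0.1], [1.0], [1.1], [5.0]])
labels = np.array([0, 0, 1, 1, 0])
print(recall_at_k(emb, labels, 1), recall_at_k(emb, labels, 4))  # 0.8 1.0
```

The last query is an outlier of class 0 whose nearest neighbours belong to class 1, so it only becomes a hit once $K$ is large enough to reach its own class.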
For verification, we train our model on the largest publicly available face dataset, CASIA-WebFace [yi2014learning], and evaluate on the standard LFW [LFWTech] dataset. The VGG face dataset [parkhi2015deep] is bigger, but many of its links have expired. The CASIA-WebFace dataset contains 494,414 images of 10,575 people. The LFW dataset consists of 13,233 images of 5,749 people. Its verification benchmark contains 6,000 verification pairs, split into 10 subsets. We select the verification threshold for one split based on the remaining nine splits.
Unless stated otherwise, we use an embedding size of 128 and an input image size of in all experiments. All models are trained using Adam [kingma2014adam] with a batch size of 200 for face verification, 80 for Stanford Online Products, and 128 for other experiments. The network architecture follows ResNet-50 [he2016deep]. To accelerate training, we use a simplified version of ResNet-50 in the face verification experiments. Specifically, we use only filters in the 5 stages respectively, instead of the originally proposed filters. We did not observe any obvious performance degradations due to the change. Horizontal mirroring and random crops from are used for data augmentation. During testing we use a single center crop. Face images are aligned by MTCNN [zhang2016joint]. When alignment fails, we use a center crop. For the margin based loss we initialize and for all experiments.
Note that some previous papers use the provided bounding boxes while others do not. To fairly compare with previous methods, we evaluate our methods on both the original images and the ones cropped by bounding boxes. For the CARS196 dataset we scale the cropped images to . For CUB200, we scale and pad the images such that their longer side is pixels, keeping the aspect ratio fixed.
Our batch construction follows FaceNet [facenet]. We use positive images per class in a batch. All positive pairs within a batch are sampled. For each example in a positive pair, we sample one negative pair. This ensures that the numbers of positive and negative pairs are balanced, and every example belongs to the same number of positive pairs and the same number of negative pairs.
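The pair construction can be sketched as follows. This sketch reads "for each example in a positive pair" as one negative pair per ordered (anchor, positive) pair, which is what makes the counts balance; that reading is an assumption. Negatives are drawn uniformly here as a placeholder for distance weighted sampling, and `build_pairs` is a hypothetical helper.

```python
import numpy as np


def build_pairs(labels, rng):
    """Ordered positive pairs (a, p) for every a != p of the same class, plus one
    negative pair (a, n) per positive pair, with n drawn from the other classes."""
    idx = np.arange(len(labels))
    pos, neg = [], []
    for a in idx:
        same = idx[(labels == labels[a]) & (idx != a)]
        other = idx[labels != labels[a]]
        for p in same:
            pos.append((int(a), int(p)))
            neg.append((int(a), int(rng.choice(other))))  # placeholder: uniform negative
    return pos, neg


labels = np.repeat(np.arange(3), 4)   # 3 classes, 4 images per class in the batch
pos, neg = build_pairs(labels, np.random.default_rng(0))
print(len(pos), len(neg))             # 36 36
```

With $m$ images per class, every example acts as the anchor of $m - 1$ positive pairs and $m - 1$ negative pairs.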
5.1 Ablation study
We start by understanding the effect of the loss function, the adaptive margin and the specific functional choice. We focus on the Stanford Online Products dataset, as it is the largest among the three image retrieval datasets. Note that image retrieval favors triplet losses over contrastive losses, since only relative relationships matter. Here all models are trained from scratch. Since different methods converge at different rates, we train all methods for 100 epochs, and report the performance at their best epoch rather than at the end of training.
We compare random sampling and semi-hard negative mining to our distance weighted sampling. For semi-hard sampling, there is no natural choice of a distance lower bound for pairwise loss functions. In this experiment we use a lower bound of to simulate the positive distance in triplet loss. We consider the contrastive loss, the triplet loss and our margin based loss. By random sampling, we refer to uniform sampling from all positive and negative pairs. Since such a definition is not applicable for triplet losses, we test only the contrastive and margin based losses.
Results are presented in \tabref{ablation}. We see that given the same loss function, different sampling distributions lead to very different performance. In particular, while the contrastive loss with random sampling yields considerably worse results than the triplet loss, its performance significantly improves when using a sampling procedure similar to that of the triplet loss. This evidence contradicts a common misconception about contrastive loss vs. triplet loss: the strength of the triplet loss comes not just from the loss function itself, but more importantly from the accompanying sampling methods. In addition, distance weighted sampling consistently offers a performance boost for almost all loss functions. The only exception is the contrastive loss. We found it to be very sensitive to its hyperparameters. While we found good hyperparameters for random and semi-hard sampling, we were not able to find a well-performing hyperparameter for distance weighted sampling. On the other hand, the margin based loss automatically learns a suitable offset and trains well. Notably, the margin based loss outperforms other loss functions by a large margin irrespective of the sampling strategy. We also tried pre-training our model on the ILSVRC 2012-CLS [deng2009imagenet] dataset, as is commonly done in prior work [oh2016deep, bell2015learning]. Pre-training offers a boost in recall. In the following sections, we focus on pre-trained models for a fair comparison.
[Figure: retrieval results for query images; columns: Query | Triplet (R@1=49.7) | Margin (R@1=61.7)]
Next, we qualitatively evaluate these methods. \figref{ir} presents the retrieval results on randomly picked query images. We can see that the triplet loss generally offers reasonable results, but makes mistakes in some cases. On the other hand, our method gives much more accurate results.
To evaluate the gains obtained by learning a flexible boundary $\beta$, we compare models using a fixed $\beta$ to models using learned $\beta$s. The results are summarized in \tabref{beta}. We see that the use of a more flexible class-specific $\beta^{(\text{class})}$ indeed offers advantages over various values of fixed $\beta$. We also tested example-specific $\beta^{(\text{img})}$, but the experiments are inconclusive. We conjecture that learning example-specific $\beta^{(\text{img})}$ might have introduced too many parameters and caused over-fitting.
Next, we measure the stability of different loss functions under different batch constructions. Specifically, we change the number of images per class in a batch and see how it impacts the solutions. For this purpose, we experiment with face verification and use the optimal verification boundary on the validation set as a summary of the solution. The results are summarized in \figref{stability}. We see that the triplet loss converges to different solutions when using different batch constructions. In addition, we observe large fluctuations in the early stage, indicating unstable training. On the other hand, the margin based loss is robust: it always converges to roughly the same geometry.
We further analyze the effects of sampling on convergence speed. We compare the margin based loss using distance weighted sampling with the two most commonly used deep embedding approaches: triplet loss with semi-hard sampling and contrastive loss with random sampling. The learning curves are shown in \figref{alg_conv}. We see that the triplet loss trained with semi-hard negative mining converges more slowly, as it ignores too many examples. The contrastive loss with random sampling converges even more slowly. Distance weighted sampling, which uses more informative and stable examples, converges faster and to a more accurate solution.
5.2 Quantitative Results
We now compare our approach to other state-of-the-art methods. Image retrieval and clustering results are summarized in \tabref{sop} and in the corresponding tables for CARS196 and CUB200-2011. We can see that our model achieves the best performance on all three datasets. In particular, the margin based loss outperforms extensions of the triplet loss, such as LiftedStruct [oh2016deep], StructClustering [song2016learnable], N-pairs [sohn2016improved], and PDDM [huang2016local]. It also outperforms histogram loss [ustinova2016learning], which requires computing similarity histograms. Also note that our model uses only one 128-dimensional embedding for each image. This is much more concise and simpler than HDC [yuan2016hard], which uses 3 embedding vectors for each image.
\tabref{lfw} presents results for face verification. Our model achieves the best accuracy among all models trained on CASIA-WebFace. Also note that here our method outperforms models using a wide range of training procedures. MFM [wu2015lightened] uses a softmax classification loss. CASIA [yi2014learning] uses a combination of softmax loss and contrastive loss. N-pair [sohn2016improved] uses a more costly loss function that is defined on all pairs in a batch. We also list, purely for reference, a few other state-of-the-art results that are not directly comparable. DeepID2 [sun2014deep] and DeepID3 [sun2015deepid3] use 25 networks on 25 face regions based on positions of facial landmarks. When trained using only one network, their performance degrades significantly. Other models such as FaceNet [facenet] and DeepFace [taigman2014deepface] are trained on huge private datasets.
Overall, our model achieves the best results on all datasets among all compared methods. Notably, our method uses the simplest loss function among all: a simple variant of the contrastive loss.
We demonstrated that sampling matters as much as, or more than, loss functions in deep embedding learning. This should not come as a surprise, since the implicitly defined loss function is (quite obviously) a sample-weighted objective.
Our new distance weighted sampling yields a performance improvement for multiple loss functions. In addition, we analyze and provide a simple margin-based loss that relaxes unnecessary constraints from traditional contrastive loss and enjoys the flexibility of the triplet loss. We show that distance weighted sampling and the margin based loss significantly outperform all other loss functions.