Learning Embeddings for Product Visual Search with Triplet Loss and Online Sampling
In this paper, we propose learning an embedding function for content-based image retrieval within the e-commerce domain using the triplet loss and an online sampling method that constructs triplets from within a minibatch. We compare our method to several strong baselines as well as recent works on the DeepFashion and Stanford Online Products datasets. Our approach significantly outperforms the state of the art on the DeepFashion dataset. With a modification to favor sampling minibatches from a single product category, the same approach demonstrates competitive results when compared to the state of the art for the Stanford Online Products dataset.
Visual search is an increasingly popular tool that enables users to quickly find similar products from large online product catalogs. At its core, a visual search system embeds both query and catalog images into a common feature space and, at search-time, uses this feature space to retrieve a query image’s k-nearest neighbors in the catalog. Whether such a system is able to surface relevant results for any given query depends heavily upon the quality of the embedding function used to represent both the query and catalog images.
Several approaches have been proposed in the shopping domain to learn a retrieval model from labeled pairs of images of the same item [Huang et al. 2015; Kiapour et al. 2015; Liu et al. 2016; Gajic and Baldrich 2018]. In these methods, a model is trained to embed images such that “positive” pairs of images with the same items are represented with features closer to each other than “negative” pairs of images containing different items.
In this work, we develop an image retrieval model using a deep convolutional neural network trained with a triplet loss. We show that a variant of “batch-hard” triplet sampling [Hermans, Beyer, and Leibe 2017] allows this model to significantly improve on the state of the art for the DeepFashion consumer-to-shop clothes retrieval dataset [Liu et al. 2016]. We demonstrate via ablation that the model improves further if we include all matched pairs during training, rather than only considering pairs that cross the query and catalog image sets. In addition, we evaluate our method on the Stanford Online Products dataset [Song et al. 2016] and show that we can obtain results competitive with current state-of-the-art methods by forming minibatches wherein all examples come from the same class for a portion of training time. We further explore when it is advantageous to use such a sampling approach.
Content-based Image Retrieval is a well-studied field of research. Babenko et al. (2014) show that features extracted from the penultimate layers of convolutional neural networks previously trained on ImageNet [Russakovsky et al. 2015] outperform handcrafted features for retrieving semantically similar images when using Euclidean distance as the distance metric. They further show that it is possible to train these same neural networks using classification data to improve retrieval performance for a given domain of interest.
Deep Metric Learning: While training to perform surrogate tasks such as classification has been shown to be an effective way to learn features for image retrieval [Babenko et al. 2014], Deep Metric Learning (DML) methods aim to directly optimize a neural network to project data such that pairs of examples labeled “similar” are represented with features that are close together in metric space and pairs of examples labeled “dissimilar” are far apart in metric space [Chopra, Hadsell, and LeCun 2005]. These approaches have been used to learn embeddings in a wide variety of tasks such as Face Verification [Chopra, Hadsell, and LeCun 2005], Person Re-identification [Hermans, Beyer, and Leibe 2017], and in our case, visual product search and fashion retrieval [Kiapour et al. 2015; Liu et al. 2016; Huang et al. 2015]. Our work adopts a deep metric learning approach to embed both query and catalog images into a metric space for image retrieval, using the Triplet loss [Schroff, Kalenichenko, and Philbin 2015] for training.
Model and Sampling Approach
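For a query image $q$, we retrieve the catalog images $c$ that are most similar to $q$ by the cosine similarity of their embeddings $f(q)$ and $f(c)$:

$$s(q, c) = \frac{f(q) \cdot f(c)}{\lVert f(q) \rVert \, \lVert f(c) \rVert}$$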
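As a minimal sketch of this retrieval step, the following Python ranks a small catalog by cosine similarity to a query embedding and returns the indices of the top k matches. The function names `cosine_similarity` and `retrieve_top_k` and the toy two-dimensional embeddings are assumptions made for illustration; they stand in for embeddings produced by the trained network, and the sketch assumes no embedding is the zero vector.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two nonzero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)


def retrieve_top_k(query_embedding, catalog_embeddings, k):
    """Return the indices of the k catalog embeddings most similar to the query."""
    scored = [
        (cosine_similarity(query_embedding, c), index)
        for index, c in enumerate(catalog_embeddings)
    ]
    # Highest similarity first.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [index for _, index in scored[:k]]


# Hypothetical toy embeddings (not taken from the paper).
catalog = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
top2 = retrieve_top_k([0.9, 0.1], catalog, k=2)
```

A brute-force scan like this is linear in the catalog size; a production system would more likely use a vectorized or approximate nearest-neighbor search.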
In our model, the embedding function $f$ is a ResNet-50-v2 [He et al. 2016] neural network that was pre-trained on ImageNet [Russakovsky et al. 2015], with the classification layer removed. The network is fine-tuned on the DeepFashion training set using the standard triplet loss [Schroff, Kalenichenko, and Philbin 2015]:

$$\mathcal{L} = \sum_{(a, p, n) \in T} \max\bigl(0,\; \alpha - s(a, p) + s(a, n)\bigr)$$

Here $T$ is the set of (anchor, positive, negative) triplets, $s$ is the cosine similarity between embeddings, and $\alpha$ is a fixed margin hyperparameter. We set $\alpha$ to 0.1 for the results shown in this paper. The choice of base network architecture and loss function follows [Gajic and Baldrich 2018].
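The triplet loss can be sketched in a few lines of Python. This is an illustrative sketch under one assumption: that the loss is written in terms of cosine similarity between embeddings, so that a triplet contributes only when the negative is not at least a margin less similar to the anchor than the positive. The names `triplet_loss` and `cosine_similarity` and the toy vectors are hypothetical; only the margin value of 0.1 comes from the text.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two nonzero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v)))


def triplet_loss(triplets, margin=0.1):
    """Sum of hinge terms max(0, margin - s(a, p) + s(a, n)) over all triplets.

    Each triplet is (anchor, positive, negative), all embedding vectors.
    Writing the loss with cosine similarity is an assumption of this sketch.
    """
    total = 0.0
    for anchor, positive, negative in triplets:
        s_ap = cosine_similarity(anchor, positive)
        s_an = cosine_similarity(anchor, negative)
        total += max(0.0, margin - s_ap + s_an)
    return total


# Toy triplets: the first is already satisfied, the second is violated.
satisfied = ([1.0, 0.0], [1.0, 0.0], [0.0, 1.0])  # s_ap = 1, s_an = 0
violated = ([1.0, 0.0], [0.0, 1.0], [1.0, 0.0])   # s_ap = 0, s_an = 1
loss = triplet_loss([satisfied, violated])
```

Only the violated triplet contributes here, which is the same property that makes uniformly sampled triplets uninformative once most constraints are satisfied.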
The key to the success of our approach is in sampling effectively from the set of triplets $T$. The number of triplets is cubic in the size of the dataset and therefore impractical to cover entirely during training. We observe empirically that uniform sampling leads to poor performance, worse than the pretrained baseline for the hyperparameter settings we tried. Various sampling methods have been proposed to combat this issue [Schroff, Kalenichenko, and Philbin 2015; Huang et al. 2015; Gajic and Baldrich 2018]. We adapt the “batch-hard” sampling technique proposed by Hermans, Beyer, and Leibe (2017) for the person re-ID setting. We construct a minibatch by first drawing $B$ “anchor” images $a_i$ uniformly at random. A “positive” image $p_i$ for each anchor is chosen at random from the set of images that contain the same item as the anchor. For each pair $(a_i, p_i)$ in the minibatch, we form one triplet by setting the negative $n_i$ to be the “positive” from another pair within the batch that is closest to the anchor $a_i$, that is, the one with the highest cosine similarity $s$ to it:

$$n_i = \operatorname*{arg\,max}_{p_j,\; j \neq i} s(a_i, p_j)$$
The loss and its gradient for the minibatch are evaluated only on these selected triplets. The selection process can also be extended to the anchors when appropriate for the dataset. Our method differs from “batch-hard” in that we sample positives uniformly, with only the negatives selected to be hard. The original batch-hard method instead created minibatches with several images of each item in the minibatch and then selected positives as well as negatives to find the hardest triplets.
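The negative selection described above can be sketched as follows. For each anchor, the sketch scans the positives of the other pairs in the minibatch and keeps the one most similar to the anchor. The function name `select_hard_negatives`, the use of cosine similarity as the notion of "closest", and the toy embeddings are assumptions for illustration; the sketch also assumes every pair in the batch comes from a different item, so that another pair's positive is always a valid negative.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two nonzero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v)))


def select_hard_negatives(anchors, positives):
    """For anchor i, return the index j != i of the most similar positive.

    anchors[i] and positives[i] form the i-th (anchor, positive) pair.
    Positives are NOT mined: each anchor keeps its randomly drawn positive.
    """
    negative_indices = []
    for i, anchor in enumerate(anchors):
        candidates = [j for j in range(len(positives)) if j != i]
        hardest = max(candidates, key=lambda j: cosine_similarity(anchor, positives[j]))
        negative_indices.append(hardest)
    return negative_indices


# Hypothetical minibatch of three (anchor, positive) pairs.
anchors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
positives = [[0.9, 0.1], [0.2, 0.9], [0.8, 0.6]]
negatives = select_hard_negatives(anchors, positives)
triplets = [(anchors[i], positives[i], positives[j]) for i, j in enumerate(negatives)]
```

Because each anchor yields exactly one triplet, a minibatch of B pairs produces B triplets, and the cost of selection is quadratic in B rather than cubic in the dataset size.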
We evaluated our methods on two publicly available image retrieval datasets with images from e-commerce websites.
DeepFashion Consumer-to-Shop Clothes Retrieval [Liu et al. 2016] is a popular dataset for evaluating the image retrieval task in the fashion domain. The Consumer-to-Shop Clothes Retrieval benchmark within this dataset contains 251,361 consumer-to-shop image pairs from online retailer Mogujie. Each image has a bounding box for one of 33,881 items, with each item belonging to one of 23 high-level clothing categories. The training/validation/test splits of the dataset consist of non-overlapping subsets of the clothing items. The photos exhibit a wide diversity in appearance. Some photos come from an “in-the-wild” setting while others come from online catalogs where the clothing items often have pristine white backgrounds. The consumer images are all “in-the-wild” while the shop images tend to provide more ideal views of the labeled clothing item. We use the item labels to form the anchor-positive pairs for training our image retrieval model as described in the section Model and Sampling Approach.
In figure 3, we use the recall at k metric [Kiapour et al. 2015] to compare our model to several other methods commonly used to learn image feature embeddings for retrieval. All models shown in this figure use the same ResNet-50-v2 base architecture. Our method (blue) outperforms generic embeddings extracted from a network that had been pre-trained on ImageNet [Russakovsky et al. 2015] (orange). It also outperforms embeddings extracted from a network trained to predict one of the 23 DeepFashion classes using a softmax cross-entropy loss (red). The previous state-of-the-art method proposed by Gajic and Baldrich (2018) (purple) uses a triplet loss but relies on an offline hard triplet mining approach, unlike our online batch-hard sampling. Our approach retrieves an exact item match within the first 20 retrieved images over 65% of the time, compared to 45% for Gajic and Baldrich. This represents a 44% relative improvement over the previous state-of-the-art result.
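For readers unfamiliar with the metric, recall at k counts a query as a success if at least one of its first k retrieved images shows the same item as the query. The sketch below is a minimal illustration, with the hypothetical function name `recall_at_k` and made-up item identifiers; it also reproduces the relative-improvement arithmetic using the rounded figures of 65% and 45% quoted in the text.

```python
def recall_at_k(ranked_item_ids, query_item_ids, k):
    """Fraction of queries whose true item appears among the first k results.

    ranked_item_ids[i] lists the item ids of the results for query i,
    best match first; query_item_ids[i] is the true item id of query i.
    """
    hits = sum(
        1
        for ranked, target in zip(ranked_item_ids, query_item_ids)
        if target in ranked[:k]
    )
    return hits / len(query_item_ids)


# Made-up rankings for three queries that all depict item "a".
rankings = [["a", "b", "c"], ["c", "a", "b"], ["b", "c", "a"]]
targets = ["a", "a", "a"]
r1 = recall_at_k(rankings, targets, k=1)
r2 = recall_at_k(rankings, targets, k=2)
r3 = recall_at_k(rankings, targets, k=3)

# Relative improvement from the rounded recall-at-20 figures in the text.
relative_gain = (0.65 - 0.45) / 0.45  # roughly 0.44, i.e. about 44%
```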
Although the model is tested by querying with “consumer” images and retrieving from a separate set of “shop” images, our model improves if we ignore the distinction between these domains during training. The retrieval performance of our model trained only using cross-domain pairs is shown in green in figure 3 for comparison.
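The difference between the two training settings comes down to which anchor-positive pairs are allowed. The following sketch enumerates matched pairs from a list of labeled images, either keeping every pair of images of the same item or keeping only pairs that cross the consumer and shop domains. The function name `matched_pairs`, the tuple layout, and the example identifiers are hypothetical, and the quadratic enumeration is meant only to make the distinction concrete.

```python
from itertools import combinations


def matched_pairs(images, cross_domain_only=False):
    """List pairs of image ids that show the same item.

    images: list of (image_id, item_id, domain) tuples, where domain is
    "consumer" or "shop". With cross_domain_only=True, pairs whose two
    images come from the same domain are dropped.
    """
    pairs = []
    for (id1, item1, dom1), (id2, item2, dom2) in combinations(images, 2):
        if item1 != item2:
            continue
        if cross_domain_only and dom1 == dom2:
            continue
        pairs.append((id1, id2))
    return pairs


# Hypothetical labeled images.
images = [
    ("c1", "dress7", "consumer"),
    ("c2", "dress7", "consumer"),
    ("s1", "dress7", "shop"),
    ("s2", "jeans3", "shop"),
    ("c3", "jeans3", "consumer"),
]
all_pairs = matched_pairs(images)
cross_pairs = matched_pairs(images, cross_domain_only=True)
```

In this toy example the unrestricted setting yields one extra pair (the two consumer photos of the same dress), which illustrates how ignoring the domain distinction enlarges the set of usable training pairs.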
DeepFashion provides a bounding box for each image, indicating the item the image is labeled as containing. We cropped to these bounding boxes for our primary experiments, but we also trained and evaluated the model using the whole images, ignoring the provided bounding boxes. The results of this experiment are shown in brown in figure 3.
In figure 1, we show some retrieval results from the DeepFashion dataset using our trained batch-hard triplet model. Our model is able to successfully retrieve photos across domains; that is, it can find the exact right product in a catalog of photos even when the query comes from a natural scene image of lower quality. In many cases, it is also robust to pose changes and oblique capture angle differences between query and catalog photos. We do observe that the model retrieves more accurately for some categories than for others. This is most apparent in categories such as “jeans”, where items within the category are harder to distinguish visually than they are for “blouses”.
Stanford Online Products dataset
We also evaluated our methods on the Stanford Online Products (SOP) dataset [Song et al. 2016], which contains 120,053 images of 22,634 items from eBay, roughly equally distributed among 12 categories such as “stapler” and “bicycle.” Unlike DeepFashion, SOP is not divided into separate domains for query and retrieval images; rather, the model should retrieve other images of the same product from the same set as the query was drawn from. Figure 4 shows results for our model and several other published results on SOP. Most previous work on this dataset reports only the recall at 1, so we compare based on this metric as well.
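Because queries and retrieval candidates come from the same set on SOP, recall at 1 has to exclude the query image itself from its own results. The sketch below illustrates that protocol under the assumption that similarity is cosine similarity between embeddings; the function name `recall_at_1_single_set` and the toy data are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two nonzero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v)))


def recall_at_1_single_set(embeddings, item_ids):
    """Recall at 1 when every image is queried against all OTHER images."""
    hits = 0
    for i, query in enumerate(embeddings):
        others = [j for j in range(len(embeddings)) if j != i]
        nearest = max(others, key=lambda j: cosine_similarity(query, embeddings[j]))
        if item_ids[nearest] == item_ids[i]:
            hits += 1
    return hits / len(embeddings)


# Toy set: two images each of items A and B, and a single image of item C.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [0.7, 0.7]]
item_ids = ["A", "A", "B", "B", "C"]
recall_1 = recall_at_1_single_set(embeddings, item_ids)
```

The lone image of item C can never be matched correctly because no other image of that item exists, so it counts as a miss in this toy example.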
The first result with “ordinary sampling” uses exactly the same model and training procedure as our best-performing model on DeepFashion. Although the datasets differ in that DeepFashion has separate “consumer” and “shop” domains, we showed that our model actually improves when ignoring these at training time. Therefore it is not surprising that the model also works well on SOP. However, the model trained with our ordinary sampling method significantly underperforms the state of the art on this dataset.
The second result with “within-class sampling” uses the same model except that the minibatches are constructed differently to improve the triplet sampling. Some fraction of minibatches – 80% for the results shown here – are constructed entirely of images from one of the SOP categories (“toaster,” “chair,” etc.). The other minibatches are constructed as described in section Model and Sampling Approach, to ensure that the model does not confuse images from different classes.
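One way to sketch this mixed minibatch construction is shown below. With a given probability the batch is drawn from a single category, and otherwise from the whole dataset. The function name `sample_minibatch_items`, the dictionary layout, and the choice to pick the category uniformly at random are assumptions of this sketch; the text specifies only the 80% fraction.

```python
import random


def sample_minibatch_items(items_by_category, batch_size,
                           within_class_fraction=0.8, rng=random):
    """Draw item ids for one minibatch.

    items_by_category maps a category name to a list of item ids.
    With probability within_class_fraction, all items come from one
    category (chosen uniformly, an assumption); otherwise items are
    drawn from the union of all categories.
    """
    if rng.random() < within_class_fraction:
        category = rng.choice(sorted(items_by_category))
        pool = items_by_category[category]
    else:
        pool = [item for items in items_by_category.values() for item in items]
    # Sample without replacement so that items in a batch are distinct.
    return rng.sample(pool, min(batch_size, len(pool)))


# Hypothetical item ids for two SOP-like categories.
catalog = {"toaster": ["t1", "t2", "t3"], "chair": ["c1", "c2", "c3"]}
rng = random.Random(0)
within_batch = sample_minibatch_items(catalog, 2, within_class_fraction=1.0, rng=rng)
mixed_batch = sample_minibatch_items(catalog, 6, within_class_fraction=0.0, rng=rng)
```

Each sampled item would then supply one anchor and one randomly chosen positive, after which hard negatives are selected within the batch as before.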
Figure 5 illustrates why this within-class sampling improves the performance of our model on Stanford Online Products, while we found it had no benefit or even degraded performance on DeepFashion. The image categories in SOP are quite distinct, and the model rarely retrieves images from the wrong class as the first predicted result. For the confusion matrices in figure 5, we treated the model as a classifier where the predicted class is the class of the first retrieved image. Even in the classes that generate the most confusion, an image is still matched with an image of the same class first 87% of the time. In contrast, the model trained on DeepFashion makes mistakes outside the item’s category as much as 31% of the time, in the case of dresses. Note that we constructed the three DeepFashion categories in figure 5 by combining classes from among the 23 provided by the metadata whenever the distinction was unclear to us; evaluating classification accuracy in the same way on the original classes gives 59.8% overall.
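The classifier view used for the confusion matrices can be illustrated with a short sketch: the predicted class of a query is simply the class of its first retrieved image. The function names and the class labels below are hypothetical and the counts are made up; they are not the values behind figure 5.

```python
from collections import Counter


def retrieval_confusion(query_classes, first_retrieved_classes):
    """Count (true class, predicted class) pairs, where the prediction is
    the class of the first retrieved image for each query."""
    return Counter(zip(query_classes, first_retrieved_classes))


def per_class_accuracy(confusion, class_name):
    """Fraction of queries of class_name whose first result shares that class."""
    total = sum(n for (true, _), n in confusion.items() if true == class_name)
    return confusion[(class_name, class_name)] / total


# Made-up example: five queries and the class of each top-1 result.
true_classes = ["dress", "dress", "dress", "jeans", "jeans"]
predicted_classes = ["dress", "dress", "top", "jeans", "jeans"]
confusion = retrieval_confusion(true_classes, predicted_classes)
dress_accuracy = per_class_accuracy(confusion, "dress")
overall_accuracy = sum(
    n for (true, pred), n in confusion.items() if true == pred
) / len(true_classes)
```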
The high classification accuracy in the case of Stanford Online Products suggests that we can greatly increase the efficiency of finding triplets that contribute to the loss by forming minibatches from within the categories. Figure 6 shows that this technique is effective on SOP, increasing the fraction of batch-hard triplets with nonzero loss by as much as 20 percentage points. No such effect is observed in the case of DeepFashion, but the fraction of non-satisfied triplets is still larger on DeepFashion.
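The quantity tracked here, the fraction of selected triplets that still violate the margin, is straightforward to compute from the anchor-positive and anchor-negative similarities of a batch. The sketch below assumes the loss is written in terms of cosine similarity with a margin of 0.1; the function name `active_triplet_fraction` and the similarity values are made up for illustration.

```python
def active_triplet_fraction(sim_anchor_positive, sim_anchor_negative, margin=0.1):
    """Fraction of triplets with nonzero hinge loss.

    A triplet is active when margin - s(a, p) + s(a, n) > 0, i.e. when the
    negative is not at least `margin` less similar to the anchor than the
    positive is.
    """
    active = sum(
        1
        for s_ap, s_an in zip(sim_anchor_positive, sim_anchor_negative)
        if margin - s_ap + s_an > 0
    )
    return active / len(sim_anchor_positive)


# Made-up similarities for a batch of four triplets.
s_ap = [0.90, 0.80, 0.50, 0.95]
s_an = [0.30, 0.75, 0.60, 0.50]
fraction = active_triplet_fraction(s_ap, s_an)
```

Only active triplets produce a gradient, so a higher fraction means that more of each minibatch contributes to training.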
Effect of batch size
It is important to use sufficiently large minibatches to allow our method to find enough useful triplets to accumulate a meaningful signal for the gradient. We found that very small batch sizes led to poor performance, as shown in Figure 7 for models trained on Stanford Online Products with 80% of batches sampled from within a class. The benefit of increasing the batch size diminishes after batches of about 32 anchor-positive pairs, with no significant improvement seen for batch sizes larger than 48 pairs. We observed similar results on the DeepFashion dataset.
We attribute the strength of our results on the DeepFashion Consumer-to-Shop Clothes Retrieval dataset compared to prior work to the effectiveness of our sampling approach. While random triplet sampling appears to fail because most triplets do not provide a useful contribution to the gradient, the offline hard negative mining used in the previous state of the art on DeepFashion [Gajic and Baldrich 2018] may be limited for the opposite reason. That is, selecting too strongly for hard negatives may overemphasize mislabeled negatives and others that are so similar to the anchor that any differences the network finds may fail to generalize, leading to noisier gradients or a large train/test gap.
We further improved our results by ignoring the distinction between the “consumer” and “shop” domains in DeepFashion at training time. We conjecture that this performance gain arises simply from increasing the effective size of the dataset by including single-domain pairs, and that the two domains in DeepFashion are not sufficiently distinct that the model would benefit much from specializing to cross-domain comparisons as was done in [Gajic and Baldrich 2018].
The Recall-at-k results on the Stanford Online Products dataset are generally greater for a given model than on DeepFashion. We found that a more strongly selective sampling strategy, finding hard triplets from within single-class batches, improved results on SOP but not on DeepFashion. These results suggest that extra sampling strategies beyond batch-hard sampling may become more useful as the dataset gets “easier” in the sense that a greater portion of the triplet constraints are easily satisfied. Our within-class sampling strategy in particular is likely to be useful whenever a model easily distinguishes the classes. In this way, additional labeling beyond the item identities may be leveraged indirectly to improve the training of a visual search model.
Although our results on Stanford Online Products are not clearly better than the state of the art, our method is simple and general and achieves similar performance to Proxy NCA [Movshovitz-Attias et al. 2017] and other triplet methods such as [Vo and Hays 2018]. The latter paper focused on finding the optimal layer of the network to use as embeddings at test time, which may be considered independently of the sampling strategy we propose. As far as we know, the only published result that surpasses our recall-at-1 performance on Stanford Online Products used an ensemble method that could also be combined with our sampling strategies.
We have shown that a simple model based on a standard deep convolutional neural network and triplet loss, when trained with an effective online sampling technique, performs well on a visual search task. Our basic method surpasses the previous state of the art on the DeepFashion Consumer-to-Shop Clothes Retrieval dataset, and we observe that allowing consumer-consumer and shop-shop matches at training time improves performance further. We also evaluated on the Stanford Online Products dataset, and we achieved results similar to the state of the art after modifying our sampling strategy to favor minibatches from within a product category. Interestingly, this modified sampling strategy did not provide a similar benefit on DeepFashion. While our basic method, based on batch-hard sampling of triplets, appears to generalize well, specific adjustments may be appropriate on some datasets and not others. Future work may also combine the insights of this work with those of previous work to improve the state of the art for product visual search.
- Babenko, A.; Slesarev, A.; Chigorin, A.; and Lempitsky, V. S. 2014. Neural codes for image retrieval. CoRR abs/1404.1777.
- Chopra, S.; Hadsell, R.; and LeCun, Y. 2005. Learning a similarity metric discriminatively, with application to face verification. In Computer Vision and Pattern Recognition, 2005. CVPR 2005. IEEE Computer Society Conference on, volume 1, 539–546. IEEE.
- Gajic, B., and Baldrich, R. 2018. Cross-domain fashion image retrieval. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops.
- He, K.; Zhang, X.; Ren, S.; and Sun, J. 2016. Identity Mappings in Deep Residual Networks. 1–15.
- Hermans, A.; Beyer, L.; and Leibe, B. 2017. In Defense of the Triplet Loss for Person Re-Identification.
- Huang, J.; Feris, R.; Chen, Q.; and Yan, S. 2015. Cross-domain image retrieval with a dual attribute-aware ranking network. In Proceedings of the IEEE International Conference on Computer Vision, 1062–1070.
- Kiapour, M. H.; Han, X.; Lazebnik, S.; Berg, A. C.; and Berg, T. L. 2015. Where to buy it: Matching street clothing photos in online shops. In Proceedings of the IEEE International Conference on Computer Vision, 3343–3351.
- Kim, W.; Goyal, B.; Chawla, K.; Lee, J.; and Kwon, K. 2018. Attention-based ensemble for deep metric learning. arXiv preprint arXiv:1804.00382.
- Liu, Z.; Luo, P.; Qiu, S.; Wang, X.; and Tang, X. 2016. DeepFashion: Powering Robust Clothes Recognition and Retrieval with Rich Annotations. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (1):1096–1104.
- Movshovitz-Attias, Y.; Toshev, A.; Leung, T. K.; Ioffe, S.; and Singh, S. 2017. No Fuss Distance Metric Learning using Proxies. In ICCV, 360–368.
- Russakovsky, O.; Deng, J.; Su, H.; Krause, J.; Satheesh, S.; Ma, S.; Huang, Z.; Karpathy, A.; Khosla, A.; Bernstein, M.; Berg, A. C.; and Fei-Fei, L. 2015. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV) 115(3):211–252.
- Schroff, F.; Kalenichenko, D.; and Philbin, J. 2015. FaceNet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, 815–823.
- Song, H. O.; Xiang, Y.; Jegelka, S.; and Savarese, S. 2016. Deep metric learning via lifted structured feature embedding. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
- Vo, N., and Hays, J. 2018. Generalization in Metric Learning: Should the Embedding Layer be the Embedding Layer?
- Wu, C.-Y.; Manmatha, R.; Smola, A. J.; and Krähenbühl, P. 2017. Sampling matters in deep embedding learning. In Proc. IEEE International Conference on Computer Vision (ICCV).