PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment
Despite the great progress made by deep CNNs in image semantic segmentation, they typically require a large number of densely-annotated images for training and generalize poorly to unseen object categories. Few-shot segmentation has thus been developed to learn to perform segmentation from only a few annotated examples. In this paper, we tackle the challenging few-shot segmentation problem from a metric learning perspective and present PANet, a novel prototype alignment network that better utilizes the information of the support set. Our PANet learns class-specific prototype representations from a few support images within an embedding space and then performs segmentation over the query images by matching each pixel to the learned prototypes. With non-parametric metric learning, PANet offers high-quality prototypes that are representative of each semantic class and discriminative between different classes. Moreover, PANet introduces a prototype alignment regularization between support and query. With this, PANet fully exploits knowledge from the support and generalizes better on few-shot segmentation. Significantly, our model achieves mIoU scores of 48.1% and 55.7% on PASCAL-5i for the 1-shot and 5-shot settings respectively, surpassing the state-of-the-art method by 1.8% and 8.6%.
Deep learning has greatly advanced the development of semantic segmentation with a number of CNN-based architectures like FCN [13], SegNet [1], DeepLab [2] and PSPNet [29]. However, training these models typically requires large numbers of images with pixel-level annotations, which are expensive to obtain. Semi- and weakly-supervised learning methods [26, 3, 9, 15] alleviate such requirements but still need many weakly annotated training images. Besides their hunger for training data, these models also generalize poorly to unseen classes. To deal with these challenges, few-shot learning, which learns new concepts from a few annotated examples, has been actively explored, mostly concentrating on image classification [25, 23, 24, 18, 6, 20, 12, 14], with a few works targeting segmentation tasks [21, 16, 4, 28, 8].
Existing few-shot segmentation methods generally learn from a handful of support images and then feed the learned knowledge into a parametric module for segmenting the query. However, such schemes have two drawbacks and thus generalize unsatisfactorily. First, they do not separate knowledge extraction from segmentation, which may be problematic since the representation of the segmentation model becomes mixed with the semantic features of the support. We therefore propose to separate these two parts into prototype extraction and non-parametric metric learning. The prototypes are optimized to be compact and robust representations of each semantic class, and the non-parametric metric learning performs segmentation through pixel-level matching within the embedding space. Second, previous methods use the support annotations only for masking, whereas we propose to also leverage them for supervising the few-shot learning process. To this end, we introduce a novel prototype alignment regularization that performs the few-shot segmentation in the reverse direction. Namely, the query image together with its predicted mask is considered as a new support set and used to segment the previous support images. In this way, the model is encouraged to generate more consistent prototypes between support and query, offering better generalization performance.
Accordingly, we develop a Prototype Alignment Network (PANet) to tackle few-shot segmentation, as shown in Figure 1. PANet first embeds different foreground objects and background into different prototypes via a shared feature extractor. In this way, each learned prototype is representative of the corresponding class and sufficiently distinguishable from other classes. Then, each pixel of the query image is labeled by referring to the class-specific prototype nearest to its embedding representation. We find that even with only one support image per class, PANet provides satisfactory segmentation results, outperforming the state of the art. Furthermore, PANet imposes a prototype alignment regularization by forming a new support set from the query image and its predicted mask and performing segmentation on the original support set. We find that this indeed encourages the prototypes generated from the queries to align well with those of the supports. Note that the model is regularized only in training, and the query images should not be confused with the testing images.
The design of the proposed PANet has several advantages. First, it introduces no extra learnable parameters and thus is less prone to over-fitting. Second, within PANet, the prototype embedding and prediction are performed on the computed feature maps, and therefore segmentation requires no extra passes through the network. In addition, as the regularization is only imposed in training, the computation cost for inference does not increase.
Our few-shot segmentation model is a generic one. Any network with a fully convolutional structure can be used as the feature extractor. It also learns well from weaker annotations, e.g., bounding boxes or scribbles, as shown in experiments. To sum up, the contributions of this work are:
We propose a simple yet effective PANet for few-shot segmentation. The model exploits metric learning over prototypes, which differs from most existing works that adopt a parametric classification architecture.
We propose a novel prototype alignment regularization that fully exploits the support knowledge to improve few-shot learning.
Our model can be directly applied to learning from a few examples with weak annotations.
Our PANet achieves mIoU of 48.1% and 55.7% on PASCAL-5i for the 1-shot and 5-shot settings, outperforming the state of the art by a margin of up to 8.6%.
2 Related work
Semantic segmentation aims to classify each pixel of an image into a set of predefined semantic classes. Recent methods are mainly based on deep convolutional neural networks [13, 10, 1, 29, 2]. For example, Long et al. [13] first adopted deep CNNs and proposed the Fully Convolutional Network (FCN), which greatly improves segmentation performance. Dilated convolutions [27, 2] are widely used to increase the receptive field without losing spatial resolution. In this work, we follow the structure of FCN to perform dense prediction and also adopt dilated convolutions to enjoy a larger receptive field. Compared to models trained with full supervision, our model can generalize to new categories with only a handful of annotated examples.
Few-shot learning aims at learning transferable knowledge across different tasks with only a few examples. Many methods have been proposed, such as methods based on metric learning [25, 23], learning the optimization process [18, 6] and applying graph-based methods [20, 12]. Vinyals et al. [25] encoded inputs into deep neural features and performed weighted nearest neighbor matching to classify unlabelled data. Snell et al. [23] proposed a Prototypical Network to represent each class with one feature vector (prototype). Sung et al. [24] used a separate module to directly learn the relation between support features and query features. Our model follows the Prototypical Network [23] and can be seen as an extension of it to dense prediction tasks, enjoying a simple design yet high performance.
Few-shot segmentation has received increasing interest recently. Shaban et al. [21] first proposed a model for few-shot segmentation that uses a conditioning branch to generate a set of parameters from the support set, which is then used to tune the segmentation process of the query set. Rakelly et al. [17] concatenated extracted support features with query features and used a decoder to generate segmentation results. Zhang et al. [28] used masked average pooling to better extract foreground/background information from the support set. Hu et al. [8] explored guiding at multiple stages of the network. These methods typically adopt a parametric module, which fuses information extracted from the support set and generates the segmentation.
Dong et al. [4] also adopted the idea of prototypical networks and tackled few-shot segmentation using metric learning. However, their model is complex, involving three training stages and complicated training configurations. Moreover, their method extracts prototypes based on an image-level loss and uses the prototypes as guidance to tune the segmentation of the query set, rather than obtaining the segmentation directly from metric learning. In comparison, our model has a simpler design and is closer to the Prototypical Network [23]. In addition, we adopt late fusion [16] to incorporate the annotation masks, which makes it easier to generalize to cases with sparse or updated annotations.
3.1 Problem setting
We aim to obtain a segmentation model that can quickly learn, from only a few annotated images, to segment new images of the same classes. As in previous works [21], we adopt the following model training and testing protocols. Suppose we are provided with images from two non-overlapping sets of classes $\mathcal{C}_{seen}$ and $\mathcal{C}_{unseen}$. The training set $\mathcal{D}_{train}$ is constructed from $\mathcal{C}_{seen}$ and the test set $\mathcal{D}_{test}$ is constructed from $\mathcal{C}_{unseen}$. We train the segmentation model $\mathcal{M}$ on $\mathcal{D}_{train}$ and evaluate it on $\mathcal{D}_{test}$.
Both the training set $\mathcal{D}_{train}$ and the testing set $\mathcal{D}_{test}$ consist of several episodes. Each episode is composed of a set of support images $S$ (with annotations) and a set of query images $Q$. Namely, $\mathcal{D}_{train} = \{(S_i, Q_i)\}_{i=1}^{N_{train}}$ and $\mathcal{D}_{test} = \{(S_i, Q_i)\}_{i=1}^{N_{test}}$, where $N_{train}$ and $N_{test}$ denote the number of episodes for training and testing respectively.
Each training/testing episode $(S_i, Q_i)$ instantiates a $C$-way $K$-shot segmentation learning task. Specifically, the support set $S_i$ has $K$ (image, mask) pairs per semantic class, and there are in total $C$ different classes from $\mathcal{C}_{seen}$ for training and from $\mathcal{C}_{unseen}$ for testing, i.e., $S_i = \{(I_{c,k}, M_{c,k})\}$ where $k = 1, 2, \dots, K$ and $c \in \mathcal{C}_i$ with $|\mathcal{C}_i| = C$. The query set $Q_i$ contains $N_{query}$ (image, mask) pairs from the same set of classes $\mathcal{C}_i$ as the support set. The model first extracts knowledge about the classes from the support set and then applies the learned knowledge to segment the query set. As each episode contains different semantic classes, the model is trained to generalize well. After obtaining the segmentation model $\mathcal{M}$ from the training set $\mathcal{D}_{train}$, we evaluate its few-shot segmentation performance on the test set $\mathcal{D}_{test}$ across all the episodes. In particular, for each testing episode the segmentation model $\mathcal{M}$ is evaluated on the query set $Q_i$ given the support set $S_i$.
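Episode construction can be illustrated with a small sketch. It assumes a hypothetical `class_to_images` dictionary that maps each class id to the ids of images containing that class. For training it would hold only seen classes, and for testing only unseen ones. The helper `sample_episode` is illustrative and not the authors' data loader. It returns image ids only, whereas real episodes also carry the masks.

```python
import random

def sample_episode(class_to_images, c_way, k_shot, n_query, rng):
    """Sample a hypothetical C-way K-shot episode.

    class_to_images: dict mapping a class id to the ids of images containing it.
    Returns (episode_classes, support, query); support and query are lists of
    (class_id, image_id) pairs, and query images never reuse a support image.
    """
    episode_classes = rng.sample(sorted(class_to_images), c_way)
    support = []
    for c in episode_classes:
        # K distinct support images for every episode class.
        support.extend((c, img) for img in rng.sample(class_to_images[c], k_shot))
    used = {img for _, img in support}
    query = []
    for _ in range(n_query):
        # Query images are drawn from the same episode classes.
        c = rng.choice(episode_classes)
        candidates = [img for img in class_to_images[c] if img not in used]
        img = rng.choice(candidates)  # raises IndexError if a class runs out of images
        used.add(img)
        query.append((c, img))
    return episode_classes, support, query
```

For example, `sample_episode(data, 1, 5, 1, random.Random(0))` draws a 1-way 5-shot episode with one query image.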
3.2 Method overview
Different from existing few-shot segmentation methods which fuse the extracted support features with the query features to generate the segmentation results in a parametric way, our proposed model aims to learn and align compact and robust prototype representations for each semantic class in an embedding space. Then it performs segmentation within the embedding space via non-parametric metric learning.
As shown in Figure 2, our model learns to perform segmentation as follows. For each episode, it first embeds the support and query images into deep features with a shared backbone network. Then it applies masked average pooling to obtain prototypes from the support set, as detailed in Section 3.3. Segmentation over the query images is performed by labeling each pixel as the class of the nearest prototype. A novel prototype alignment regularization (PAR), introduced in Section 3.5, is applied during learning to encourage the model to learn consistent embedding prototypes for the support and query.
We adopt a VGG-16 [22] network as the feature extractor, following common practice. The first 5 convolutional blocks of VGG-16 are kept for feature extraction and the other layers are removed. The stride of the maxpool4 layer is set to 1 to maintain a large spatial resolution. To increase the receptive field, the convolutions in the conv5 block are replaced by dilated convolutions with dilation set to 2. As the proposed PAR introduces no extra learnable parameters, our network is trained end-to-end to optimize the weights of VGG-16 for learning a consistent embedding space.
3.3 Prototype learning
Our model learns a representative and well-separated prototype representation for each semantic class, including the background, based on the prototypical network [23]. Instead of averaging over the whole input image, PANet leverages the mask annotations over the support images to learn prototypes for foreground and background separately. There are two strategies to exploit the segmentation masks, i.e., early fusion and late fusion [16]. Early fusion masks the support images before feeding them into the feature extractor [21, 8, 4]. Late fusion directly masks over the feature maps to produce foreground/background features separately [28, 17]. In this work, we adopt the late fusion strategy since it keeps the input consistent for the shared feature extractor. Concretely, given a support set $S_i = \{(I_{c,k}, M_{c,k})\}$, let $F_{c,k}$ be the feature map output by the network for the image $I_{c,k}$. Here $c$ indexes the class and $k = 1, \dots, K$ indexes the support image. The prototype of class $c$ is computed via masked average pooling [28]:

$$p_c = \frac{1}{K}\sum_{k} \frac{\sum_{x,y} F_{c,k}^{(x,y)}\, \mathbb{1}[M_{c,k}^{(x,y)} = c]}{\sum_{x,y} \mathbb{1}[M_{c,k}^{(x,y)} = c]} \tag{1}$$

where $(x, y)$ indexes the spatial locations and $\mathbb{1}(\cdot)$ is an indicator function, outputting value 1 if the argument is true or 0 otherwise. In addition, the prototype of background is computed by

$$p_{bg} = \frac{1}{CK}\sum_{c,k} \frac{\sum_{x,y} F_{c,k}^{(x,y)}\, \mathbb{1}[M_{c,k}^{(x,y)} \notin \mathcal{C}_i]}{\sum_{x,y} \mathbb{1}[M_{c,k}^{(x,y)} \notin \mathcal{C}_i]} \tag{2}$$
The above prototypes are optimized end-to-end through non-parametric metric learning as explained below.
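Masked average pooling can be sketched in pure Python. Lists stand in for tensors: a feature map is a flat list of per-pixel feature vectors, and a mask is the matching list of per-pixel labels. In the usage example, label 0 marks background, though any label outside the episode classes counts as background. The function names are hypothetical. The sketch also assumes features and masks already share one resolution, a step the text does not describe.

```python
def mean_vectors(vectors):
    """Element-wise mean of equally sized feature vectors."""
    n = len(vectors)
    return [sum(v[d] for v in vectors) / n for d in range(len(vectors[0]))]

def masked_average_pooling(features, mask, keep):
    """Average the feature vectors at pixels whose mask label satisfies keep(label)."""
    selected = [f for f, label in zip(features, mask) if keep(label)]
    if not selected:
        raise ValueError("the mask selects no pixels")
    return mean_vectors(selected)

def support_prototypes(support):
    """Class and background prototypes from a support set.

    support: dict mapping class c -> list of K (features, mask) shots.
    """
    episode_classes = set(support)
    prototypes, background = {}, []
    for c, shots in support.items():
        # Pool the class-c pixels of each shot, then average over the K shots.
        prototypes[c] = mean_vectors(
            [masked_average_pooling(f, m, lambda v: v == c) for f, m in shots])
        # Pixels outside all episode classes are background for this shot.
        background.extend(
            masked_average_pooling(f, m, lambda v: v not in episode_classes)
            for f, m in shots)
    # Averaging the C*K per-shot background means gives the 1/(CK) factor.
    prototypes["bg"] = mean_vectors(background)
    return prototypes
```

Because pooling happens on the feature maps rather than on the input images (late fusion), the same extractor output can be reused for any mask.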
3.4 Non-parametric metric learning
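We adopt a non-parametric metric learning method to learn the optimal prototypes and perform segmentation accordingly. Since segmentation can be seen as classification at each spatial location, we calculate the distance between the query feature vector at each spatial location and each computed prototype. Then we apply a softmax over the distances to produce a probability map $\tilde{M}_q$ over semantic classes (including background). Concretely, given a distance function $d$, let $\mathcal{P} = \{p_c \mid c \in \mathcal{C}_i\} \cup \{p_{bg}\}$ and let $F_q$ denote the query feature map. For each $p_j \in \mathcal{P}$ we have

$$\tilde{M}_q^{j;(x,y)} = \frac{\exp\big(-\alpha\, d(F_q^{(x,y)}, p_j)\big)}{\sum_{p_{j'} \in \mathcal{P}} \exp\big(-\alpha\, d(F_q^{(x,y)}, p_{j'})\big)} \tag{3}$$

The predicted segmentation mask is then given by

$$\hat{M}_q^{(x,y)} = \arg\max_j \tilde{M}_q^{j;(x,y)} \tag{4}$$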
The distance function $d$ commonly adopts the cosine distance or the squared Euclidean distance. Snell et al. [23] claimed that the squared Euclidean distance greatly outperforms the cosine distance. However, Oreshkin et al. [14] attributed the improvement to the interaction between the different scaling of the two metrics and the softmax function: multiplying the cosine distance by a factor $\alpha$ achieves performance comparable to the squared Euclidean distance. Empirically, we find that the cosine distance is more stable and gives better performance, possibly because it is bounded and thus easier to optimize. The multiplier $\alpha$ is fixed at 20 since we find that learning it yields little performance gain.
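The per-pixel classification can be sketched with a scaled cosine distance. The sketch takes the distance as d = 1 - cos, so the negated scaled distance equals the multiplier times cos minus a constant, and that constant cancels in the softmax. The multiplier defaults to 20 as stated above. The plain-list representation and the function names are assumptions for illustration only.

```python
import math

def cosine_similarity(a, b, eps=1e-8):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)

def segment_query(query_features, prototypes, alpha=20.0):
    """Per-pixel softmax over prototypes and the arg-max label per pixel.

    With d = 1 - cos, -alpha * d = alpha * cos - alpha; the constant -alpha
    cancels inside the softmax, so alpha * cos is used as the logit.
    """
    keys = list(prototypes)
    probs, labels = [], []
    for f in query_features:
        logits = [alpha * cosine_similarity(f, prototypes[k]) for k in keys]
        top = max(logits)  # subtract the max logit for numerical stability
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        p = {k: e / total for k, e in zip(keys, exps)}
        probs.append(p)
        labels.append(max(p, key=p.get))  # nearest prototype wins
    return probs, labels
```

Without the multiplier, cosine logits lie in [-1, 1] and the softmax stays close to uniform. Scaling by 20 sharpens it, which is the scaling effect discussed above.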
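After computing the probability map $\tilde{M}_q$ for the query image via metric learning, we calculate the segmentation loss $\mathcal{L}_{seg}$ as follows:

$$\mathcal{L}_{seg} = -\frac{1}{N}\sum_{x,y}\sum_{p_j \in \mathcal{P}} \mathbb{1}[M_q^{(x,y)} = j] \log \tilde{M}_q^{j;(x,y)} \tag{5}$$

where $M_q$ is the ground truth segmentation mask of the query image and $N$ is the total number of spatial locations. Optimizing the above loss derives suitable prototypes for each class.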
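The segmentation loss is a pixel-averaged cross-entropy. The sketch below assumes per-pixel probabilities are given as dictionaries keyed by label (a class id or `"bg"`). The small `eps` guard against `log(0)` is an addition of this sketch and is not mentioned in the text.

```python
import math

def segmentation_loss(probs, ground_truth, eps=1e-12):
    """Pixel-averaged cross-entropy for one query image.

    probs: list of {label: probability} dicts, one per pixel.
    ground_truth: list of labels (class id or "bg"), one per pixel.
    """
    n = len(ground_truth)
    # -(1/N) * sum over pixels of log p(ground-truth label); eps avoids log(0)
    return -sum(math.log(max(p[g], eps)) for p, g in zip(probs, ground_truth)) / n
```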
3.5 Prototype alignment regularization (PAR)
In previous works, the support annotations are used only for masking, which does not adequately exploit the support information for few-shot learning. In this subsection, we elaborate on the prototype alignment regularization (PAR), which exploits the support information better to guide the few-shot learning procedure and helps enhance the generalizability of the resulting model from a few examples.
Intuitively, if the model can predict a good segmentation mask for the query using prototypes extracted from the support, the prototypes learned from the query set based on the predicted masks should be able to segment the support images well. Thus, PAR encourages the resulting segmentation model to perform few-shot learning in the reverse direction, i.e., taking the query and the predicted mask as the new support to learn to segment the support images. This imposes a mutual alignment between the prototypes of the support and query images and learns richer knowledge from the support. Note that all the support and query images here are from the training set $\mathcal{D}_{train}$.
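Figure 2 illustrates PAR in detail. After obtaining a segmentation prediction for the query image, we perform masked average pooling on the query features accordingly and obtain another set of prototypes $\bar{\mathcal{P}} = \{\bar{p}_c \mid c \in \mathcal{C}_i\} \cup \{\bar{p}_{bg}\}$, following Eqns. (1) and (2). Next, the non-parametric method introduced in Section 3.4 is used to predict the segmentation masks for the support images. The predictions are compared with the ground truth annotations to calculate a loss $\mathcal{L}_{PAR}$. The entire procedure for implementing PAR can be seen as swapping the support and query sets. Concretely, within PAR, the segmentation probability of the support image $I_{c,k}$ is given by

$$\tilde{M}_{c,k}^{j;(x,y)} = \frac{\exp\big(-\alpha\, d(F_{c,k}^{(x,y)}, \bar{p}_j)\big)}{\sum_{\bar{p}_{j'} \in \{\bar{p}_c, \bar{p}_{bg}\}} \exp\big(-\alpha\, d(F_{c,k}^{(x,y)}, \bar{p}_{j'})\big)} \tag{6}$$

and the loss is computed by

$$\mathcal{L}_{PAR} = -\frac{1}{CKN}\sum_{c,k,x,y}\sum_{\bar{p}_j \in \{\bar{p}_c, \bar{p}_{bg}\}} \mathbb{1}[M_{c,k}^{(x,y)} = j] \log \tilde{M}_{c,k}^{j;(x,y)} \tag{7}$$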
Without PAR, the information only flows one-way from the support set to the query set. By flowing the information back to the support set, we force the model to learn a consistent embedding space that aligns the query and support prototypes. The aligning effect of the proposed PAR is validated by experiments in Section 4.3.
The total loss for training our PANet model is thus

$$\mathcal{L} = \mathcal{L}_{seg} + \lambda \mathcal{L}_{PAR} \tag{8}$$

where $\lambda$ serves as the regularization strength, and $\lambda = 0$ reduces to the model without PAR. In our experiments, we keep $\lambda$ at 1 since other values give little improvement. The whole training and testing procedures for PANet on few-shot segmentation are summarized in Algorithm 1.
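One training episode, with both the forward (support to query) and the PAR (query to support) directions, can be sketched end to end. This is a self-contained pure-Python illustration under several assumptions, not Algorithm 1 itself:

- Masks hold class ids, and any label outside the episode classes is background.
- When a predicted query mask contains no pixel of some class, that prototype becomes a zero vector.
- In the PAR direction, pixels of *other* episode classes are ignored and the loss is averaged over the remaining pixels.

All function names are hypothetical.

```python
import math

def mean_vec(vectors, dim):
    # Element-wise mean; an empty selection gives a zero vector (sketch assumption).
    if not vectors:
        return [0.0] * dim
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]

def pool(features, mask, keep):
    # Masked average pooling over the pixels whose label satisfies keep(label).
    return mean_vec([f for f, m in zip(features, mask) if keep(m)], len(features[0]))

def predict(features, prototypes, alpha=20.0):
    # Per-pixel softmax over alpha * cosine similarity to each prototype.
    keys, out = list(prototypes), []
    for f in features:
        logits = []
        for k in keys:
            p = prototypes[k]
            norm = math.sqrt(sum(x * x for x in f)) * math.sqrt(sum(x * x for x in p))
            logits.append(alpha * sum(x * y for x, y in zip(f, p)) / max(norm, 1e-8))
        top = max(logits)
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        out.append({k: e / total for k, e in zip(keys, exps)})
    return out

def cross_entropy(probs, targets):
    # Mean -log p(target) over pixels whose target is not None (None = ignored).
    pairs = [(p, t) for p, t in zip(probs, targets) if t is not None]
    return -sum(math.log(max(p[t], 1e-12)) for p, t in pairs) / max(len(pairs), 1)

def panet_loss(support, query_features, query_gt, lam=1.0):
    """Return (L_seg, L_PAR, L_seg + lam * L_PAR) for one training episode.

    support: {class_id: [(features, mask), ...]} with K shots per class.
    """
    classes, dim = set(support), len(query_features[0])
    # Support prototypes via masked average pooling (Section 3.3).
    protos = {c: mean_vec([pool(f, m, lambda v: v == c) for f, m in shots], dim)
              for c, shots in support.items()}
    protos["bg"] = mean_vec([pool(f, m, lambda v: v not in classes)
                             for shots in support.values() for f, m in shots], dim)
    # Segment the query and score it against its ground truth.
    q_probs = predict(query_features, protos)
    loss_seg = cross_entropy(q_probs, [v if v in classes else "bg" for v in query_gt])
    # PAR: pool new prototypes from the query using its *predicted* mask...
    q_pred = [max(p, key=p.get) for p in q_probs]
    q_protos = {k: pool(query_features, q_pred, lambda v: v == k) for k in protos}
    # ...and segment every support image with {class c, background}.
    par_terms = []
    for c, shots in support.items():
        pair = {c: q_protos[c], "bg": q_protos["bg"]}
        for f, m in shots:
            target = [c if v == c else ("bg" if v not in classes else None) for v in m]
            par_terms.append(cross_entropy(predict(f, pair), target))
    loss_par = sum(par_terms) / len(par_terms)
    return loss_seg, loss_par, loss_seg + lam * loss_par
```

Setting `lam=0.0` recovers the model without PAR. Only the loss changes; the prediction path used at test time is identical.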
3.6 Generalization to weaker annotations
Our model is generic and is directly applicable to other types of annotations. First, it accepts weaker annotations on the support set, such as scribbles and bounding boxes indicating the foreground objects of interest. Experiments in Section 4.4 show that even with weak annotations, our model is still able to extract robust prototypes from the support set and give comparably good segmentation results for the query images. Compared with pixel-level dense annotations, weak annotations are easier and cheaper to obtain [9]. Second, by adopting late fusion [16], our model can quickly adapt to updated annotations with little computation overhead and thus can be applied in interactive segmentation. We leave this for future work.
We follow the evaluation scheme proposed in [21] and evaluate our model on the PASCAL-5i [21] dataset. The dataset is created from PASCAL VOC 2012 [5] with SBD [7] augmentation. The 20 categories in PASCAL VOC are evenly divided into 4 splits, each containing 5 categories. Models are trained on 3 splits and evaluated on the remaining one in a cross-validation fashion. The categories in each split can be found in [21]. During testing, previous methods randomly sample 1,000 episodes for evaluation, but we find this is not enough to give stable results. In our experiments, we average the results of 5 runs with different random seeds, each run containing 1,000 episodes.
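The cross-validation protocol and the seed averaging can be sketched as follows. The sketch *assumes* that fold i holds the consecutive 1-based class ids 5i+1 to 5i+5. The authoritative class lists are the ones given in [21], so check against that reference before relying on this layout. `run_episodes` is a hypothetical callback that evaluates a number of randomly sampled test episodes for a given seed and returns one score. The seeds 0 to 4 are illustrative.

```python
import statistics

def pascal5i_fold(fold, num_classes=20, num_folds=4):
    """Return (train_classes, test_classes) for one cross-validation fold.

    Assumes fold i holds consecutive 1-based class ids (an assumption of this sketch).
    """
    per_fold = num_classes // num_folds
    test = list(range(fold * per_fold + 1, (fold + 1) * per_fold + 1))
    train = [c for c in range(1, num_classes + 1) if c not in test]
    return train, test

def evaluate_fold(run_episodes, seeds=(0, 1, 2, 3, 4), episodes_per_run=1000):
    """Average the score of several seeded runs of evaluation episodes."""
    return statistics.mean(run_episodes(s, episodes_per_run) for s in seeds)
```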
Following [8], we also evaluate our model on a more challenging dataset built from MS COCO [11]. Similarly, the 80 object classes in MS COCO are evenly divided into 4 splits, each containing 20 classes. We follow the same scheme for training and testing as on PASCAL-5i. is used for all experiments.
We adopt two metrics for model evaluation, mean-IoU and binary-IoU. Mean-IoU measures the Intersection-over-Union (IoU) for each foreground class and averages over all the classes [21, 28]. Binary-IoU treats all object categories as one foreground class and averages the IoU of foreground and background [17, 4, 8]. We mainly use the mean-IoU metric because it considers the differences between foreground categories and therefore more accurately reflects the model performance. Results w.r.t. the binary-IoU are also reported for clear comparisons with some previous methods.
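The difference between the two metrics can be made concrete with a small sketch for 1-way episodes. Each episode gives per-pixel predicted and ground-truth labels, with 1 for the episode's foreground class and 0 for background. Counts are accumulated over all episodes, which is our assumption; the cited works may aggregate differently.

- **Mean-IoU:** counts are kept per foreground class, and the per-class IoUs are averaged.
- **Binary-IoU:** all foreground classes are pooled into a single foreground, and its IoU is averaged with the background IoU.

```python
def _counts(pred, gt, label):
    """(true positives, false positives, false negatives) for one label."""
    tp = sum(1 for p, g in zip(pred, gt) if p == label and g == label)
    fp = sum(1 for p, g in zip(pred, gt) if p == label and g != label)
    fn = sum(1 for p, g in zip(pred, gt) if p != label and g == label)
    return tp, fp, fn

def _iou(tp, fp, fn):
    return tp / (tp + fp + fn) if tp + fp + fn else 0.0

def mean_and_binary_iou(episodes):
    """episodes: list of (class_id, pred, gt) with per-pixel labels 1 = fg, 0 = bg."""
    per_class = {}
    fg_total, bg_total = [0, 0, 0], [0, 0, 0]
    for class_id, pred, gt in episodes:
        fg, bg = _counts(pred, gt, 1), _counts(pred, gt, 0)
        acc = per_class.setdefault(class_id, [0, 0, 0])
        for i in range(3):
            acc[i] += fg[i]       # mean-IoU: counts kept separately per class
            fg_total[i] += fg[i]  # binary-IoU: all classes pooled into one foreground
            bg_total[i] += bg[i]
    mean_iou = sum(_iou(*c) for c in per_class.values()) / len(per_class)
    binary_iou = (_iou(*fg_total) + _iou(*bg_total)) / 2
    return mean_iou, binary_iou
```

Take one episode that is half right for class 1 and one that misses class 2 entirely. Mean-IoU exposes the failure on class 2 directly, whereas binary-IoU is lifted by the easy background pixels.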
We initialize the VGG-16 network with the weights pre-trained on ILSVRC [19], as in previous works [21, 4, 28]. Input images are resized to (417, 417) and augmented with random horizontal flipping. The model is trained end-to-end by SGD with a momentum of 0.9 for 30,000 iterations. The learning rate is initialized to 1e-3 and reduced by a factor of 0.1 every 10,000 iterations. The weight decay is 0.0005 and the batch size is 1.
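The optimization settings above can be collected into a small configuration sketch with the step learning-rate schedule written out explicitly. The dictionary keys are hypothetical names chosen for readability; the values are the ones stated in the text.

```python
def learning_rate(iteration, base_lr=1e-3, gamma=0.1, step=10_000):
    """Step decay: the rate is multiplied by gamma every `step` iterations."""
    return base_lr * gamma ** (iteration // step)

# Hyper-parameters from the text, gathered in one (hypothetically named) dict.
TRAIN_CONFIG = {
    "optimizer": "SGD",
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "batch_size": 1,
    "iterations": 30_000,
    "input_size": (417, 417),
    "augmentation": ["random_horizontal_flip"],
}
```

Over 30,000 iterations this yields three phases, at 1e-3, 1e-4 and 1e-5. In PyTorch, a `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)` stepped once per iteration would follow the same schedule.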
We set up a baseline model, denoted PANet-init, which is initialized with the weights pre-trained on ILSVRC [19] but not further trained on PASCAL-5i. We also compare our PANet with two baseline models from [17], FG-BG and fine-tuning. FG-BG trains a foreground-background segmentor that is independent of the support, while fine-tuning tunes a pre-trained foreground-background segmentor on the support.
4.2 Comparison with state-of-the-arts
Table 1 compares our model with other methods on the PASCAL-5i dataset in terms of mean-IoU. Our model outperforms the state-of-the-art methods in both the 1-shot and 5-shot settings while using fewer parameters. In the 5-shot task, our model achieves a significant improvement of 8.6%. Under the binary-IoU metric, as shown in Table 2, our model also achieves the highest performance. It is worth noting that our method does not use any decoder module or post-processing techniques to refine the results.
As Tables 1 and 2 show, the performance gap between the 1-shot and 5-shot settings is small for other methods (less than 3.1% in mean-IoU), implying that these methods gain little from more support information. In contrast, our model yields a much more significant performance gain (up to 7.6% in mean-IoU) since it learns more effectively from the support set. The evaluation results of our baseline model PANet-init also confirm this point. Without training, it rivals the state of the art in the 5-shot setting and gains more than 11% in mean-IoU when given more support images.
As in [4, 28], we evaluate our model on multi-way few-shot segmentation tasks. Without loss of generality, we perform evaluations on 2-way 1-shot and 2-way 5-shot segmentation tasks. Table 3 summarizes the results. Our PANet outperforms previous works by a large margin of more than 20%.
Qualitative results for 1-way and 2-way segmentation are shown in Figure 3 and Figure 4. Without any decoder structure or post-processing, our model gives satisfying segmentation results on unseen classes with only one annotated support image. This demonstrates the strong learning and generalization abilities of our model. Note that the prototype extracted from the same support image can be used to successfully segment query images with appearance variations. For example, in Figure 3 row 1, our model successfully segments bicycles that are cluttered with other objects (1st example), viewed from a different perspective (2nd example), or only partly visible (3rd example). On the other hand, prototypes extracted from one part of the object can be used to segment whole objects of the same class (row 2 in Figure 3). This shows that the proposed PANet is capable of extracting robust prototypes for each semantic class from a few annotated examples. More qualitative examples can be found in the supplementary material.
We also present some challenging cases on which our model fails. As the first failure case in Figure 3 shows, our model tends to give segmentation results with unnatural patches, possibly because it predicts each location independently. This can be alleviated by post-processing. From the second failure case, we find that our model is unable to distinguish between chairs and tables, since they have similar prototypes in the embedding space.
Table 4 shows the evaluation results on the MS COCO dataset. Our model outperforms the previous A-MCG [8] by 7.2% in the 1-shot setting and 8.2% in the 5-shot setting. Compared to PASCAL VOC, MS COCO has more object categories, which makes the differences between the two evaluation metrics more significant. Qualitative results on MS COCO are shown in Figure 3.
4.3 Analysis on PAR
The proposed PAR encourages the model to learn a consistent embedding space that aligns the support and query prototypes. Besides reducing the distances between the support and query prototypes, models trained with PAR achieve better results (shown in Table 5) and converge faster during training.
Aligning embedding prototypes
By flowing the information from the query set back to the support set via PAR, our model can learn a consistent embedding space and align the prototypes extracted from the support and query set. To verify this, we randomly choose 1,000 episodes from PASCAL-5i split-1 in the 1-way 5-shot task. Then for each episode we calculate the Euclidean distance between prototypes extracted from the query set and the support set. The averaged distance computed by models with PAR is 32.2, much smaller than 42.6 by models without PAR. With PAR, our model is able to extract prototypes that are better aligned in the embedding space.
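The alignment measurement can be sketched as a mean Euclidean distance between matching prototypes. The text does not specify which prototypes enter the average. This sketch assumes all prototypes shared by the two sets are used, including background, averages them per episode, and then averages over episodes. Prototypes are plain lists keyed by class id or `"bg"`, and the function names are hypothetical.

```python
import math

def prototype_distance(support_protos, query_protos):
    """Mean Euclidean distance between prototypes sharing a key (class id or 'bg')."""
    keys = support_protos.keys() & query_protos.keys()
    return sum(math.dist(support_protos[k], query_protos[k]) for k in keys) / len(keys)

def average_alignment(episodes):
    """episodes: list of (support_protos, query_protos) pairs; mean over episodes."""
    return sum(prototype_distance(s, q) for s, q in episodes) / len(episodes)
```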
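Table 5: mean-IoU (%) on PASCAL-5i with and without PAR.
| Method | 1-shot | 5-shot |
|---|---|---|
| PANet w/o PAR | 47.2 | 54.9 |
| PANet | 48.1 | 55.7 |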
Speeding up convergence
In our experiments, we observe that models trained with PAR converge faster than models without it, as reflected by the training loss curves in Figure 5. This shows that PAR accelerates convergence and helps the model reach a lower loss, especially in the 5-shot setting, because with PAR the information from the support set can be better exploited.
4.4 Test with weak annotations
We further evaluate our model with scribble and bounding box annotations. During testing, the pixel-level annotations of the support set are replaced by scribbles or bounding boxes, which are generated automatically from the dense segmentation masks. Each bounding box is obtained from one randomly chosen instance mask in each support image. As Table 6 shows, our model works well with very sparse annotations and is robust to the noise introduced by the bounding boxes. In the 1-shot case, the model performs comparably with the two types of annotations, but in the 5-shot case, using scribbles outperforms using bounding boxes by 2%. A possible reason is that, with more support information, scribbles give more representative prototypes while bounding boxes introduce more noise. Qualitative results using scribble and bounding box annotations are shown in Figure 6.
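Deriving a bounding-box annotation from a dense instance mask can be sketched as follows. Masks are assumed to be binary row-major lists. `box_annotation` picks one instance at random, mirroring the description above, and both function names are hypothetical. Every background pixel inside the box becomes foreground, which is the source of the annotation noise discussed above. The scribble generation procedure is not described, so it is not sketched here.

```python
import random

def box_from_instance(instance_mask):
    """Convert a binary instance mask (list of rows) into a filled bounding-box mask."""
    rows = [i for i, row in enumerate(instance_mask) if any(row)]
    if not rows:
        return [[0] * len(row) for row in instance_mask]
    cols = [j for j in range(len(instance_mask[0]))
            if any(row[j] for row in instance_mask)]
    r0, r1, c0, c1 = min(rows), max(rows), min(cols), max(cols)
    return [[1 if r0 <= i <= r1 and c0 <= j <= c1 else 0 for j in range(len(row))]
            for i, row in enumerate(instance_mask)]

def box_annotation(instance_masks, rng):
    """Pick one instance of a support image at random and return its box mask."""
    return box_from_instance(rng.choice(instance_masks))
```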
We propose a novel PANet for few-shot segmentation based on metric learning. PANet is able to extract robust prototypes from the support set and performs segmentation using non-parametric distance calculation. With the proposed PAR, our model can further exploit the support information to assist training. Without any decoder structure or post-processing step, our PANet outperforms previous work by a large margin.
Jiashi Feng was partially supported by NUS IDS R-263-000-C67-646, ECRA R-263-000-C87-133 and MOE Tier-II R-263-000-D17-112.
- [1] (2017) SegNet: a deep convolutional encoder-decoder architecture for image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence 39 (12), pp. 2481–2495.
- [2] (2018) DeepLab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. IEEE Transactions on Pattern Analysis and Machine Intelligence 40 (4), pp. 834–848.
- [3] (2015) BoxSup: exploiting bounding boxes to supervise convolutional networks for semantic segmentation. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1635–1643.
- [4] (2018) Few-shot semantic segmentation with prototype learning. In BMVC, Vol. 3, pp. 4.
- [5] (2010) The PASCAL Visual Object Classes (VOC) challenge. International Journal of Computer Vision 88 (2), pp. 303–338.
- [6] (2017) Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1126–1135.
- [7] (2011) Semantic contours from inverse detectors.
- [8] (2018) Attention-based multi-context guiding for few-shot semantic segmentation.
- [9] (2016) ScribbleSup: scribble-supervised convolutional networks for semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3159–3167.
- [10] (2017) RefineNet: multi-path refinement networks for high-resolution semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1925–1934.
- [11] (2014) Microsoft COCO: common objects in context. In European Conference on Computer Vision, pp. 740–755.
- [12] (2018) Learning to propagate labels: transductive propagation network for few-shot learning.
- [13] (2015) Fully convolutional networks for semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3431–3440.
- [14] (2018) TADAM: task dependent adaptive metric for improved few-shot learning. In Advances in Neural Information Processing Systems, pp. 719–729.
- [15] (2015) Weakly- and semi-supervised learning of a deep convolutional network for semantic image segmentation. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1742–1750.
- [16] (2018) Few-shot segmentation propagation with guided networks. arXiv preprint arXiv:1806.07373.
- [17] (2018) Conditional networks for few-shot semantic segmentation.
- [18] (2016) Optimization as a model for few-shot learning.
- [19] (2015) ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV) 115 (3), pp. 211–252.
- [20] (2018) Few-shot learning with graph neural networks. In International Conference on Learning Representations.
- [21] (2017) One-shot learning for semantic segmentation. arXiv preprint arXiv:1709.03410.
- [22] (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
- [23] (2017) Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pp. 4077–4087.
- [24] (2018) Learning to compare: relation network for few-shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1199–1208.
- [25] (2016) Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pp. 3630–3638.
- [26] (2017) Object region mining with adversarial erasing: a simple classification to semantic segmentation approach. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1568–1576.
- [27] (2015) Multi-scale context aggregation by dilated convolutions. arXiv preprint arXiv:1511.07122.
- [28] (2018) SG-One: similarity guidance network for one-shot semantic segmentation. arXiv preprint arXiv:1810.09091.
- [29] (2017) Pyramid scene parsing network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2881–2890.