An Iterative Approach for Multiple Instance Learning Problems
Multiple Instance Learning (MIL) algorithms are tasked with learning how to associate sets of elements with specific set-level outputs. Towards this goal, the main challenge of MIL lies in modelling the underlying structure that characterizes sets of elements. Existing methods addressing MIL problems are usually tailored to either a specific underlying set structure; specific prediction tasks, e.g. classification or regression; or a combination of both. Here we present an approach where a set representation is learned iteratively by looking at the constituent elements of each set one at a time. The iterative analysis of set elements allows our approach to update the set representation so that it reflects whether relevant elements have been detected and whether the underlying structure has been matched. These features provide our method with some model explanation capabilities. Despite its simplicity, the proposed approach not only effectively models different types of underlying set structures, but is also capable of handling both classification and regression tasks, all while requiring minimal modifications. An extensive empirical evaluation shows that the proposed method is able to reach and surpass the state-of-the-art.
Traditional single-instance classification methods focus on learning a mapping from a feature vector (extracted from a single instance) to a specific class label of interest. In a complementary fashion, Multiple Instance Learning (MIL) [MILencyclopedia] algorithms are tasked with learning how to associate a set of elements, usually referred to as a "bag", with a specific label. In comparison, MIL methods usually require weaker supervision in the form of set-level labels. The capability of making predictions over groups of elements while requiring weak supervision is a characteristic that makes this family of methods attractive to address several real-world applications. Examples include drug activity prediction, image classification [Carbonneau2018MultipleIL, MIimagesWei2016], image retrieval, sound classification [MIsoundClassification], anomaly detection [MITanomalyDetection], medical imaging [MILbreastCancer] and web-mining.
Performing predictions at the level of sets of elements introduces several challenges. On the one hand, the label of a set can be defined as a function of specific instance-level characteristics of the elements that compose it. On the other hand, this set label can be defined as a function of relationships occurring between the constituent elements. Since several relationships are possible between the elements, this latter scenario constitutes a more challenging problem. As can be found in the MIL literature [AmoresMILsurvey2013, Carbonneau2018MultipleIL, foulds_frank_2010], several algorithms have been proposed during the last decade, each tailored to address specific tasks/goals formulated as an MIL problem. Each of these tasks possesses specific set characteristics, usually referred to as the "Multiple Instance (MI) assumptions". The most common of these assumptions, i.e. the standard MI assumption, states that a set is positive if it contains at least one positive element (usually called witness); otherwise the set is negative. Here we propose a general method that is able to go beyond this standard assumption and model other possible underlying MI structures.
The proposed method follows a two-step iterative set pooling approach to address MIL problems. Given a set of elements, in a first step, each of the elements is encoded by an Instance Description Unit. Then, in a second step, each of the feature-encoded elements is passed to the Iterative Set Pooling Unit. This unit is tasked with embedding and iteratively aggregating all the information from the different elements into a set-level representation (Fig. 1). Our approach is flexible in the sense that it is capable of characterizing sets of elements, i.e. modeling the underlying MI assumption, by iteratively looking at each of the elements that compose the set one at a time. Our extensive evaluation shows strengths of our method on three fronts. First, it is able to model several MI assumptions, e.g. single or multiple witness element detection, counting or collective assumptions. Second, it can address several types of prediction problems, i.e. classification and regression. Third, it obtains competitive or superior performance w.r.t. the state-of-the-art. All of this is achieved while requiring minimal modifications. Moreover, a deeper analysis of our iterative set pooling unit suggests that it is capable of highlighting, internally, the element (or group of elements) that triggers the given prediction, thus possessing explainability capabilities.
Our contributions are three-fold: i) a novel iterative method powered with feedback mechanisms that is capable of modelling the underlying assumption / relationship that characterizes the elements in a set without the need of explicit heuristics, ii) a robust framework able to handle various typical assumptions considered by MIL problems, and iii) an approach to address MIL problems powered with explainability capabilities.
2 Related Work
Over the last decade various approaches have been proposed to address several types of MIL problems. Since our work is based on deep neural networks, we position our approach w.r.t. efforts based on neural networks, specifically those with deep architectures. Please refer to [AmoresMILsurvey2013, Carbonneau2018MultipleIL, foulds_frank_2010] for detailed surveys covering non-deep methods.
[RamonMINN] constitutes one of the first efforts towards addressing MIL problems through neural networks. The proposed multiple instance neural network (MINN) estimates instance probabilities which are aggregated at the last layer using a convex max operator in order to predict a set probability. This idea is further extended in [WangDeepMIL], which uses a neural network to learn a set representation and directly carry out set classification without estimating instance-level probabilities or labels. In parallel, [ITW:2018] proposed an attention mechanism to learn a pooling operation over instances. The weights learned for the attention mechanism on instances can serve as indicators of the contribution of each instance to the final decision, thus producing explainable predictions. [LiuCVPR17attention] proposed a similar idea, using the computed set representations to measure distances between image sets. [yan18DynamicPooling] proposed to update the contributions of the instances by observing all the instances of the set for a predefined number of iterations. Along a different direction, [tiboMMIN] proposed a hierarchical set representation in which each set is internally divided into subsets until reaching the instance level. Very recently, [MingMIGNN] proposed to consider the elements in the sets to be non-i.i.d. and used graph neural networks to learn a set embedding.
Similar to [ITW:2018, WangDeepMIL] we embed the instance features from each set into a common space from which a set representation is learned. This set representation is used to directly make set predictions related to MIL problems. Similar to [MingMIGNN] and [yan18DynamicPooling] we aim at learning the underlying structure within the sets. Different from [MingMIGNN], our method does not rely on hand-tuned parameters, e.g. distance thresholds to define edges in the graph, or other manual graph construction. Moreover, the improvement in performance displayed by our method is not sensitive to the possible lack of structure within each set. Compared to [yan18DynamicPooling], our method only requires a single pass through all the instances. Moreover, our method is able to go beyond binary classification tasks and handle more complex classification and regression tasks. Finally, most of the works mentioned above operate under the standard MI learning assumption. In contrast, the proposed approach is able to learn the underlying structure of sets of instances, thus being robust to several MIL assumptions/problems [foulds_frank_2010].
3 Multiple Instance Assumptions
Before describing the proposed approach, we introduce set characteristics or assumptions that have been commonly considered in order to define set-level labels. While different surveys [AmoresMILsurvey2013, Carbonneau2018MultipleIL, foulds_frank_2010] have grouped these assumptions based on different criteria, we focus on the following general assumptions, which can be adapted to meet more specific ones.
3.1 Standard Multiple Instance Assumption
Given a set of instances with latent instance-level labels, traditional MIL problems aim at predicting a binary set-level label for each set. Under the standard MI assumption, a set is positive if and only if at least one of the elements/instances that compose it satisfies a pre-defined desired property.
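As a minimal sketch of the standard MI assumption, the snippet below labels a set as positive exactly when at least one of its elements satisfies a property. The names `standard_mi_label` and `is_witness` are illustrative assumptions and do not come from the original method.

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def standard_mi_label(bag: Iterable[T], is_witness: Callable[[T], bool]) -> int:
    """Return 1 if at least one element of the bag is a witness, else 0."""
    return int(any(is_witness(x) for x in bag))


# Hypothetical example: the property of interest is "the element equals 9".
print(standard_mi_label([1, 4, 9, 7], lambda d: d == 9))  # set with a witness
print(standard_mi_label([1, 4, 7], lambda d: d == 9))     # set without a witness
```

Note that under this definition an empty set is negative, since no element can act as a witness.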
3.2 Collective Multiple Instance Assumption
Under the standard MI assumption only a small subset of elements, possessing a given property, contribute to the prediction of the set label. In contrast, under the collective MI assumption, all the elements contribute equally to the predicted set label.
Here the set label is obtained by comparing a set-level score against a threshold value, where the set-level score is computed from the contributions of each of the elements.
An alternative is the collective weighted MI assumption, under which different elements may have different degrees of influence on the label of the set, as determined by their weights. Under this weighted assumption, the set score is computed from the weighted contributions of the elements, scaled by a normalization term.
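Since the formulas of this subsection are not reproduced here, the block below gives one plausible formalization as a sketch. All symbols are assumptions introduced for illustration: $X = \{x_1, \dots, x_N\}$ is a set, $f(x_n)$ is the contribution of element $x_n$, $\theta$ is the threshold, $w(x_n)$ is the weight of element $x_n$, and $Z$ is the normalization term. The use of $\geq$ in the threshold test is also an assumption.

```latex
% Collective MI assumption: every element contributes equally.
Y =
\begin{cases}
1, & \text{if } S(X) \geq \theta \\
0, & \text{otherwise}
\end{cases}
\qquad
S(X) = \frac{1}{N} \sum_{n=1}^{N} f(x_n)

% Collective weighted MI assumption: contributions are weighted.
S_w(X) = \frac{1}{Z} \sum_{n=1}^{N} w(x_n)\, f(x_n),
\qquad
Z = \sum_{n=1}^{N} w(x_n)
```

Under this reading, setting all weights equal gives $Z = N\,w$ and $S_w(X) = S(X)$, so the weighted assumption reduces to the unweighted one.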
3.3 Rank-based MI Assumption
This assumption states that there exists a property of every element in the set under which the elements, or a subset of them, can be ranked in a fixed order. Taking this property into account, a set is considered positive if its elements follow a specific order (as in sequential data); otherwise it is considered negative.
This type of assumption is related to problems processing sequential data, e.g. action recognition [DBLP:SharmaKS15] or language modeling [DBLP:conf/interspeech/SundermeyerSN12].
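To make the rank-based assumption concrete, the following sketch labels a set as positive when an element of one class appears at a lower index than an element of another class. The choice of digits 3 and 6 mirrors the digit sequence scenario of Section 5.2; the function name and signature are assumptions made for illustration.

```python
from typing import Sequence


def rank_based_label(bag: Sequence[int], first: int = 3, second: int = 6) -> int:
    """Return 1 if some `first` occurs at a lower index than some `second`."""
    seen_first = False
    for element in bag:
        if element == first:
            seen_first = True
        elif element == second and seen_first:
            # A `second` was found after an earlier `first`: order satisfied.
            return 1
    return 0


print(rank_based_label([3, 1, 6]))  # '3' precedes '6'
print(rank_based_label([6, 1, 3]))  # '6' precedes '3'
```

The single boolean `seen_first` is the only state carried across elements, which hints at why a model for this assumption needs some form of memory.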
4 Proposed Method
The proposed approach consists of three main components. Given a set of elements, each of the elements is encoded into a feature representation through the Instance Description Unit (Section 4.1). Then, each encoded element is fed to the Iterative Set Pooling Unit (Section 4.2), producing the aggregated set representation. Finally, a prediction is obtained by evaluating the set representation via the Prediction Unit (Section 4.3).
4.1 Instance Description Unit
This component receives the set elements in raw form, i.e. each of the instances that compose the set, in its original format. It is tasked with encoding the input set data into a format that can be processed by the rest of the pipeline. As such, it provides the proposed method with robustness to different data formats/modalities. More formally, given a dataset of sets paired with their corresponding set-level labels, each of the sets is encoded into a feature representation. This is achieved by pushing each of the elements that compose the set through a feature encoder, producing an instance-level representation for each element.
Selection of this component depends on the modality of the data to be processed, e.g. VGG [VGG] or ResNet [He2015DeepRL] features for still images; Word2Vec [MikolovWord2Vec] or TF-IDF for text data; or rank-pooled features [FernandoAl:TPAMI16] or dynamic images [bilenDynamicImages] for video data.
4.2 Iterative Set Pooling Unit
The main goal of this component is to derive a set-level representation that is able to encode all the elements of the set and any possible underlying structure between them. We aim at learning a set representation that is independent of both the cardinality of the set and the nature of the underlying structure. Starting from the element-level representations computed in the previous step, this is achieved by iteratively looking at the representation of each of the elements, one at a time. In each iteration an updated set-level representation is computed. In parallel, a feedback loop provides information regarding the state of the set representation that will be considered at the next iteration. Finally, after observing all the elements in the set, the final set representation is taken as the output of this component.
The notion behind this iterative set pooling idea is that elements observed at specific iterations can be used to compute a more-informed set-level representation at later iterations, thus allowing the method to encode underlying relationships or structures among the elements of the set. While this iterative assumption may hint at a sequence structure requirement within each set, our empirical evaluation strongly suggests this not to be the case. Moreover, this provides the proposed approach with robustness towards sets possessing a sequence-like structure, while not enforcing the requirement of the existence of such a structure.
In practice, this iterative mechanism can be implemented through Recurrent Neural Networks [RNN_Schuster_97], Long Short Term Memory (LSTM) Networks [LSTM_97], Gated Recurrent Units [choGRU], or any other machinery that allows information to persist across multiple observations. Here, we implement this component through LSTMs given their robustness in modeling structures within sets of high cardinality. This is intended to ensure that the learned set representation can encode structures between all the elements in the set, independently of the cardinality of the set. More specifically, we use Bi-directional LSTMs, which observe the elements in a set in both the left-to-right and right-to-left directions. This further helps to properly model the context in which the elements of the set occur.
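The iterative pooling idea can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the implementation used in the paper: a simple tanh recurrent cell stands in for the LSTM, the weights are random and untrained, and the names `RecurrentSetPooler` and `bidirectional_pool` are hypothetical.

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


class RecurrentSetPooler:
    """Toy recurrent pooler: h_t = tanh(W_x x_t + W_h h_{t-1})."""

    def __init__(self, in_dim: int, hid_dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.hid_dim = hid_dim
        # Random, untrained weights (assumption: training is out of scope here).
        self.w_x = [[rng.uniform(-0.5, 0.5) for _ in range(in_dim)]
                    for _ in range(hid_dim)]
        self.w_h = [[rng.uniform(-0.5, 0.5) for _ in range(hid_dim)]
                    for _ in range(hid_dim)]

    def step(self, x: Sequence[float], h: Sequence[float]) -> Vector:
        """Update the set representation h after observing one element x."""
        return [
            math.tanh(sum(a * b for a, b in zip(self.w_x[i], x))
                      + sum(a * b for a, b in zip(self.w_h[i], h)))
            for i in range(self.hid_dim)
        ]

    def pool(self, elements: Sequence[Sequence[float]]) -> List[Vector]:
        """Return the set representation after each observed element."""
        h = [0.0] * self.hid_dim
        states = []
        for x in elements:
            h = self.step(x, h)  # feedback: previous state feeds the next step
            states.append(h)
        return states


def bidirectional_pool(fwd: RecurrentSetPooler, bwd: RecurrentSetPooler,
                       elements: Sequence[Sequence[float]]) -> Vector:
    """Concatenate the final states of a forward and a backward pass."""
    forward_final = fwd.pool(elements)[-1]
    backward_final = bwd.pool(list(reversed(elements)))[-1]
    return forward_final + backward_final


fwd, bwd = RecurrentSetPooler(3, 4, seed=1), RecurrentSetPooler(3, 4, seed=2)
small_set = [[0.1, 0.2, 0.3], [0.5, 0.1, 0.0]]
large_set = [[0.1, 0.1, 0.1] for _ in range(7)]
# Both sets map to a representation of the same size (2 * hid_dim).
print(len(bidirectional_pool(fwd, bwd, small_set)),
      len(bidirectional_pool(fwd, bwd, large_set)))
```

The sketch requires a non-empty set. Its main point is that the size of the output depends only on the hidden dimension, not on the number of elements observed.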
4.3 Prediction Unit
Given the set-level representation of a set, this component is tasked with making a set label prediction that serves as the final output of the pipeline. The selection of the prediction function is related to the task of interest. This unit provides our method with flexibility to address both classification and regression prediction tasks.
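The sketch below illustrates how swapping the prediction function changes the task while the set representation stays the same. A linear layer followed by a sigmoid threshold is assumed for binary classification, and a linear layer followed by rounding for counting-style regression. The function names, the 0.5 threshold and the clamping of counts at zero are illustrative assumptions.

```python
import math
from typing import Sequence


def linear(rep: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Linear score computed from a set-level representation."""
    return sum(w * r for w, r in zip(weights, rep)) + bias


def predict_binary(rep: Sequence[float], weights: Sequence[float],
                   bias: float, threshold: float = 0.5) -> int:
    """Classification head: sigmoid probability compared to a threshold."""
    prob = 1.0 / (1.0 + math.exp(-linear(rep, weights, bias)))
    return int(prob >= threshold)


def predict_count(rep: Sequence[float], weights: Sequence[float],
                  bias: float) -> int:
    """Regression head: continuous output rounded to a non-negative count."""
    return max(0, round(linear(rep, weights, bias)))


set_representation = [1.0, 2.0]  # hypothetical 2-dimensional representation
print(predict_binary(set_representation, [1.0, 1.0], 0.0))
print(predict_count(set_representation, [1.0, 1.0], 0.2))
```

Only the head differs between the two tasks, which is consistent with the claim that moving between classification and regression requires minimal modifications.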
4.4 Explaining Model Predictions
Up to this point, we have presented an iterative method to make predictions from a set-level representation through the use of a prediction function. While being able to make accurate predictions is of importance, being able to provide an explanation supporting the prediction made is a desirable property for any automatic system. In MIL algorithms, these explanations usually come in the form of highlighting the elements or instances of the set which determine the predicted set label.
In the proposed approach this can be achieved by probing the set representation after each of the elements is embedded into it. More specifically, in an initial step we can push every element through the set pooling unit and store the set representation computed after the embedding of each element. Then, the relevant elements can be highlighted by identifying the elements with a strong effect on the computed set representation. Finally, the selection of elements can be further verified by the response that their corresponding set-level representations produce when evaluated by the prediction unit.
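The probing procedure can be sketched as follows. The paper does not specify how a "strong effect" on the set representation is measured, so the Euclidean distance between consecutive stored representations is used here as an assumption. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def state_changes(states: Sequence[Sequence[float]]) -> List[float]:
    """Distance between consecutive set representations.

    The first stored state is compared against an all-zero initial state.
    """
    previous = [0.0] * len(states[0])
    changes = []
    for state in states:
        changes.append(math.dist(previous, state))
        previous = list(state)
    return changes


def highlight(states: Sequence[Sequence[float]], top_k: int = 1) -> List[int]:
    """Indices of the top_k elements that moved the representation the most."""
    changes = state_changes(states)
    ranked = sorted(range(len(changes)), key=lambda i: changes[i], reverse=True)
    return sorted(ranked[:top_k])


# Hypothetical stored representations: the third element causes a large jump.
stored = [[0.0, 0.1], [0.0, 0.1], [0.9, 0.8], [0.9, 0.8]]
print(highlight(stored, top_k=1))
```

In a complete pipeline, the highlighted indices would then be checked against the output of the prediction unit on the corresponding intermediate representations, as described above.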
5 Experimental Evaluation
We conduct a series of experiments to assess empirically the performance of the proposed approach under different data modalities and considering different MI assumptions.
5.1 Drug Activity Prediction
First, we conduct experiments on the Drug Activity Prediction task proposed in [foulds_frank_2010]. This is a standard benchmark used to assess the performance of MIL methods. It is composed of two datasets, i.e., MUSK1 and MUSK2, which consist of 47/45 and 39/63 positive/negative sets, respectively. Elements within a set correspond to different conformations of a molecule, with each conformation being described by a 166-dimensional vector. The task is to predict whether new molecules will be musks or non-musks.
Discussion: A quick glance at Table 1 shows that the performance of the proposed method is comparable to that of state-of-the-art methods. More specifically, on MUSK1 it surpasses the closest method. For the case of MUSK2, its performance is on par with most of the competitors, with the exception of the Dynamic Pooling method, which achieves superior performance at the cost of a larger number of computations. Overall, the proposed method achieves competitive classification accuracy on this task.
|Method||MUSK1||MUSK2|
|Atten. Based||89.2 ± 4.0||85.8 ± 4.8|
|Gated Atten. Based||90.0 ± 5.0||85.8 ± 4.8|
|Dyn. Pool||90.7 ± 3.6||92.6 ± 4.3|
|Ours||93.2 ± 6.0||85.4 ± 5.5|
5.2 MI Predictions on Simplified Visual Data
This experiment focuses on performing MI predictions based on visual data. Following the protocol from [ITW:2018] we use images from the MNIST dataset [lecun-MNIST-2010] to construct image sets to define four scenarios, each following a different assumption: Single digit occurrence, Multiple digit occurrence, Digit sequences and Single digit counting. For each scenario we sample images from MNIST to construct 500 image sets for training and 200 sets for testing. Label balance is preserved within each data split.
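The set construction protocol can be sketched as below for the single digit occurrence scenario. To keep the example self-contained, digit class labels stand in for MNIST images. The Gaussian cardinality model, the rejection sampling used to balance labels, and the values `mean_size=10` and `std_size=2` are assumptions chosen for illustration.

```python
import random
from typing import List, Tuple


def sample_bag(rng: random.Random, mean_size: float, std_size: float) -> List[int]:
    """Sample a bag of digit classes with Gaussian-distributed cardinality."""
    size = max(1, round(rng.gauss(mean_size, std_size)))
    return [rng.randrange(10) for _ in range(size)]


def make_balanced_bags(n_bags: int, mean_size: float, std_size: float,
                       target: int = 9, seed: int = 0) -> List[Tuple[List[int], int]]:
    """Rejection-sample bags until positives and negatives are balanced."""
    rng = random.Random(seed)
    quota = {0: n_bags // 2, 1: n_bags - n_bags // 2}
    bags = []
    while quota[0] > 0 or quota[1] > 0:
        bag = sample_bag(rng, mean_size, std_size)
        label = int(target in bag)  # standard MI assumption on the target digit
        if quota[label] > 0:
            quota[label] -= 1
            bags.append((bag, label))
    return bags


bags = make_balanced_bags(n_bags=20, mean_size=10, std_size=2, seed=1)
print(sum(label for _, label in bags), "positive sets out of", len(bags))
```

In an actual experiment each sampled class index would be replaced by an image of that digit drawn from the corresponding MNIST split.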
For this series of experiments, we use a LeNet [lenet] as Instance Description Unit and an LSTM whose input and cell state have 500 dimensions each (please refer to the supplementary material for more details). We compare the obtained performance w.r.t. the attention-based model from [ITW:2018] and the dynamic pooling method from [yan18DynamicPooling]. Mean error rate in the binary classification task is adopted as performance metric in these experiments.
|Method||single digit()||single digit()||multiple digits||digit sequence||digit counting|
|Atten. Based||2.8 ± 4.8||4.5 ± 0.4||28.5 ± 0.7||47.3 ± 3.2||33.4 ± 19.3|
|Gated Atten. Based||4.0 ± 0.9||4.6 ± 0.5||27.4 ± 0.9||47.0 ± 2.9||11.9 ± 3.6|
|Dyn. Pool||5.6 ± 1.1||6.1 ± 1.2||28.5 ± 6.6||47.9 ± 2.5||25.4 ± 1.8|
|Ours||3.5 ± 1.1||3.1 ± 0.5||6.4 ± 1.4||2.8 ± 0.7||9.0 ± 2.7|
Single Digit Occurrence
In this scenario we follow the standard MI assumption and label a set as positive if at least one digit '9' occurs in the set. The digit '9' is selected since it can be easily mistaken for the digits '4' and '7' [ITW:2018], thus introducing some element-level ambiguity. We define sets with a fixed mean cardinality and verify the effect that the variation in cardinality has on performance by testing two standard deviation values. We repeat this experiment five times, generating different sets and weight initializations. We report mean performance in Table 2 (columns II and III).
Discussion: The results indicate that, in this task, our performance is comparable with the state-of-the-art for the lower standard deviation value and superior as the standard deviation increases. This is to some extent expected, since at the lower value the cardinality (i.e. the number of elements) of each set is almost fixed. This setting is favorable for the attention-based method since it operates in a feed-forward fashion. Yet, note the high standard deviation in performance produced by this baseline. On the contrary, at the higher value there is a higher variation of cardinality across sets. Under this setting, feed-forward approaches start to produce higher errors. Here our method produces superior performance w.r.t. the state-of-the-art (Table 2, column III).
Multiple Digit Occurrence
This is an extension of the previous scenario in which instead of focusing on the occurrence of a single digit class, the model should recognize the occurrence of instances of two digit classes. More specifically, a set is labeled positive if both digits ’3’ and ’6’ occur in it, without considering the order of occurrence. For this scenario 1,000 sets are sampled for training. Results are reported in Table 2 (column IV).
Discussion: It is remarkable that when making this simple extension of considering the occurrence of multiple digits, i.e. '3' and '6', the state-of-the-art methods suffer a significant drop in performance. This drop puts the state-of-the-art methods well below the performance of our method (Table 2, column IV). Please note that in this experiment the order (or location) of the two digits does not matter. This suggests that the proposed iterative set pooling unit can handle multiple elements of interest, independent of the ordering in which they occur within the sets. Compared to the Single Digit Occurrence scenario, here observing multiple elements is of interest, so the model needs to “remember” the information that it has seen in order to assess whether instances of the classes of interest have been encountered. The feed-forward models lack information persistence mechanisms, which translates to a poor ability to remember and to handle multiple elements of interest. Surprisingly, in spite of its iterative nature, the Dynamic Pooling method is not able to preserve the information it has observed across iterations, resulting in similar performance to the other baselines.
Digit Sequences
Similar to the previous setting, in this scenario multiple elements are of interest within each set; however, the order in which they occur does matter. More specifically, a set is labeled positive if an instance of digit '3' occurs earlier, i.e. it has a lower index in the set, than one of digit '6'. This scenario follows the Rank-based MI assumption presented in Section 3. Quantitative results are reported in Table 2 (column V).
Discussion: As can be seen in Table 2 (column V), under this scenario the proposed method leads the performance table by a large margin. This is to some extent expected since the LSTM network used to implement our iterative set pooling unit is designed to handle sets whose instances possess an underlying sequential structure.
Single Digit Counting
Previous scenarios addressed the classification task of predicting positive/negative set-level labels. In contrast, in this scenario, we focus on the regression task of counting the number of instances of a specific digit class of interest within the set. In order to make our approach suitable to address a regression problem, instead of using a classifier as prediction unit we use a regressor whose continuous output is rounded in order to provide a discrete count value as output. In this experiment the digit '9' is selected as the class to be counted. The mean cardinality of each set is fixed. Performance is reported in Table 2 (column VI).
Discussion: From Table 2 (column VI) the same trend can be observed: our method has superior performance and higher stability than the attention-based model. When conducting this counting task, our method obtains a lower error than both the attention-based models and the dynamic pooling method. These results support the capability of the method to handle regression problems.
5.3 MI Predictions on Realistic Visual Data
We complement the experiments from Section 5.2 by considering sets composed of more complex visual data. Towards this goal, we use images from a fashion-related dataset, namely Lookbook [Lookbook]. Images in this dataset are divided into two domains: catalog clothing images and their corresponding human model images where a person is wearing the clothing product. Each clothing product has one catalog image and several human model images. We only consider the products with five or more human model images, resulting in 6616 unique products (latent classes) with around 63k images in total. Every product image has 5-55 human model images. The training set contains 4000 classes while the validation and test sets have 616 and 2000 classes, respectively. We run two experiments on this dataset as described in the following sections.
Given the higher complexity of images in this dataset, we use a VGG16 [VGG] as Instance Description Unit (please refer to the supplementary material for more details). Moreover, for the iterative set pooling unit, we set the dimensionality of the input and cell state of our LSTM to .
Outlier detection
This is a binary classification task where the goal is to indicate whether a given set of images contains an outlier image. Image sets in this experiment are composed exclusively of human model images. This experiment follows an inverse version of the standard MI assumption where a set is considered positive, i.e. without outliers, if all its constituent images (elements) belong to the same clothing product (class). Otherwise, if one of the images in the set does not belong to the clothing product, the set is considered negative, i.e. with an outlier. In this experiment all the sets have 5 images.
|Atten. Based||14.0 ± 0.9|
|Dyn. Pool||42.1 ± 0.3|
Discussion: The result in Table 3 indicates that our method still achieves superior performance on a realistic dataset. More concretely, our method produces a mean error that is 3 pp lower w.r.t. the attention-based method and a significant 31 pp lower w.r.t. the dynamic pooling baseline. It is noteworthy that, while somewhat related to the standard MI assumption, this is a harder setting. Here the “witness element” that defines the sets with outliers has very high variance, i.e. it can be any element with a different class w.r.t. those in the set. Compared to the single or multiple digit occurrence experiments, where the model only has to be aware of the specific fixed digit(s), in this task the outlier can be any image, which means the models should understand every element in the set.
Cross-domain clothing retrieval
For this experiment, human model images are used as queries while catalog images serve as database, thus defining a many-to-one retrieval task. The cardinality of each set is the same as the number of human model images of each product (class). We conduct two variants of this experiment. In the first variant we use the complete image, as it is originally provided. The second is an occluded variant where every human model image in a set is divided into a grid of 16 blocks, 12 of which are occluded by setting all the pixels therein to black. By doing so, every single image in a set can only show part of the information while their combination (i.e. the whole set) represents the complete clothing item. Catalog images in the database are not occluded in this experiment.
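A minimal sketch of the occlusion step is given below, assuming a 4x4 grid (16 blocks) over a grayscale image stored as nested lists. The text does not state how the 12 occluded blocks are chosen, so picking them uniformly at random is an assumption, as is the function name `occlude`.

```python
import random
from typing import List

Image = List[List[int]]  # grayscale image as rows of pixel intensities


def occlude(image: Image, rng: random.Random,
            grid: int = 4, n_hidden: int = 12) -> Image:
    """Return a copy of the image with n_hidden of grid*grid blocks set to black."""
    height, width = len(image), len(image[0])
    block_h, block_w = height // grid, width // grid
    hidden = set(rng.sample(range(grid * grid), n_hidden))
    out = [row[:] for row in image]  # leave the input image untouched
    for r in range(height):
        for c in range(width):
            # Pixels past the last full block are assigned to the last block.
            block_row = min(r // block_h, grid - 1)
            block_col = min(c // block_w, grid - 1)
            if block_row * grid + block_col in hidden:
                out[r][c] = 0
    return out


white = [[255] * 8 for _ in range(8)]
occluded = occlude(white, random.Random(0))
visible = sum(v != 0 for row in occluded for v in row)
print(visible, "of 64 pixels remain visible")
```

The sketch assumes the image is at least `grid` pixels high and wide. With 4 of 16 blocks left visible, each image reveals a quarter of the item, so the complete item can only be recovered by combining several images of the set.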
As baselines, in addition to the attention-based model, we follow DeepFashion [DeepFashion] and train a model to perform retrieval by computing distances from single image representations instead of set-based representations. Following the multiple queries approach from [multipleQueries], we report performance of three variants of this method: Single-AVE, where the distance of each set is computed as the average of the distances from every image in the set w.r.t. an item in the database; Single-MIN, where the distance of the set is defined as the minimum distance of an image in the set w.r.t. an item in the database; and Single Fea. AVE, where the distance of the set is calculated as the distance of a prototype element w.r.t. an item in the database. As prototype element we use the average of the feature representations of every element in the set. We refer to these baselines as Single-image models.
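The three single-image baselines differ only in how per-image information is combined into a set-to-item distance, which the sketch below makes explicit. Euclidean distance and the function names are assumptions made for illustration. The feature vectors are toy values.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def single_ave(query_set: Sequence[Vector], item: Vector) -> float:
    """Single-AVE: average of the per-image distances to the database item."""
    return sum(math.dist(q, item) for q in query_set) / len(query_set)


def single_min(query_set: Sequence[Vector], item: Vector) -> float:
    """Single-MIN: smallest per-image distance to the database item."""
    return min(math.dist(q, item) for q in query_set)


def single_fea_ave(query_set: Sequence[Vector], item: Vector) -> float:
    """Single Fea. AVE: distance from the mean feature (prototype) to the item."""
    prototype: List[float] = [sum(dim) / len(query_set) for dim in zip(*query_set)]
    return math.dist(prototype, item)


queries = [[0.0, 0.0], [2.0, 0.0]]  # toy features of two query images
catalog_item = [1.0, 0.0]           # toy feature of one database item
print(single_ave(queries, catalog_item),
      single_min(queries, catalog_item),
      single_fea_ave(queries, catalog_item))
```

Single-AVE and Single-MIN compute one distance per query image and database item, whereas Single Fea. AVE computes a single distance per database item after averaging the features.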
This retrieval task is to some extent related to the collective MI assumption (Sec. 3) since all the elements in the set contribute to the task handled by the model.
Discussion: Table 4 shows that in the original setting our method tends to obtain superior recall values in the majority of the cases, with the exception of the case when the closest 20 items (recall@20) are considered. When looking at the occluded variant of the experiment, a quick glance at Table 5 shows that, compared to the original setting, absolute performance values in this setting are much lower. This is to be expected since this is a more challenging scenario where the model needs to learn the information cumulatively by aggregating information from parts of different images. In this occluded setting, our method clearly outperforms the attention-based model, the dynamic pooling model and the methods based on single-image distances. This could be attributed to the information persistence component that is part of our method. This component allows our method to select what to remember and what to ignore from each of the elements that it observes when updating the set representation used to compute distances. The difference w.r.t. the Single-AVE and Single-MIN baselines is quite remarkable given that they require a significantly larger number of element-wise distance computations w.r.t. items in the database. This may lead to scalability issues when the dataset size increases, as the number of distance computations grows with the product of the set cardinality and the database size.
Moreover, in both occluded and non-occluded datasets, we notice that the Single-image model baselines have a superior performance w.r.t. the attention-based model and dynamic pooling model. We hypothesize that this is because the single-image models can better exploit important features, e.g. discriminative visual patches, since they compute distances directly in an element-wise fashion. In contrast, it is likely that some of these nuances might get averaged out by the feature aggregation step that is present in the attention-based model.
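For reference, the recall values discussed above can be computed as in the following sketch, assuming the many-to-one setting in which each query set has exactly one correct catalog item. The function name and the toy identifiers are hypothetical.

```python
from typing import Sequence


def recall_at_k(ranked_ids: Sequence[Sequence[int]],
                true_ids: Sequence[int], k: int) -> float:
    """Fraction of queries whose correct item appears among the top-k results."""
    hits = sum(1 for ranked, true in zip(ranked_ids, true_ids) if true in ranked[:k])
    return hits / len(true_ids)


# Toy example: two query sets, database items ranked by increasing distance.
rankings = [[3, 1, 2], [2, 3, 1]]
correct = [1, 1]
for k in (1, 2, 3):
    print(k, recall_at_k(rankings, correct, k))
```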
|Single Fea. AVE||20.15||56.25||67.85||81.50|
|Single Fea. AVE||5.10||25.60||36.95||54.65|
5.4 Explaining Model Predictions
In this section we analyze the explanation capabilities of our method. Towards this goal, in Fig. 4 we show the predicted output after observing each element of the set. Since the set pooling unit utilizes a Bi-LSTM, which processes forward and backward directions of the set together, we show the two directions of the set. In addition, we verify the capabilities of the proposed set representation to encode the underlying MI assumption. This would be indicated by significant variations in the set representation when observing the elements involved in the MI assumption. We ease the visualization of the high-dimensional representation by plotting its corresponding t-SNE [tsne] projection in Fig. 3. See the supp. material for more examples.
Discussion: In Fig. 4, we can notice that each time one of the elements that determine the MI assumption is observed, the set representation is updated in such a way that there is a significant change in the prediction made by the model. This is further supported by the state of the internal representation as shown by the corresponding t-SNE visualizations (Fig. 3). For the Single Digit Occurrence case, we notice that the representation gets updated to a different region of the space when the digit of interest is observed. More specifically, from the third row of Fig. 3, it is clear that the space is divided into two parts: the set representation of negative sets changes within the bottom-left region, while for positive sets, once the digit of interest occurs, the set representation jumps to the top-right region and ends there. Similarly, for the Multiple Digit Occurrence, Digit Sequences and Digit Counting cases, we notice that the representation shifts significantly to specific regions (green and magenta dots) every time one of the digits of interest is observed. Moreover, for Multiple Digit Occurrence and Digit Sequences the representation seems to always reach a common region once the underlying MI assumption has been completely satisfied.
6 Conclusion
We presented an iterative approach to address MIL problems. Our method is capable of learning the underlying structure that characterizes each of the sets by looking at its constituent elements one at a time. Despite its simplicity, the proposed method is able to effectively model a variety of underlying MI assumptions and handle both classification and regression tasks while requiring minimal modifications. A deeper analysis of the learned set representation reveals that our method is able to highlight witness elements which are relevant to the underlying set structure, thus providing some explanation capabilities.