Prediction-Constrained Training for Semi-Supervised Mixture and Topic Models


Michael C. Hughes, School of Engineering and Applied Sciences, Harvard University. Contact email:
Leah Weiner, Dept. of Computer Science, Brown University
Gabriel Hope, School of Information and Computer Sciences, Univ. of California, Irvine
Thomas H. McCoy, Jr., Massachusetts General Hospital
Roy H. Perlis, Massachusetts General Hospital
Erik B. Sudderth, School of Information and Computer Sciences, Univ. of California, Irvine
Finale Doshi-Velez, School of Engineering and Applied Sciences, Harvard University

Supervisory signals have the potential to make low-dimensional data representations, like those learned by mixture and topic models, more interpretable and useful. We propose a framework for training latent variable models that explicitly balances two goals: recovery of faithful generative explanations of high-dimensional data, and accurate prediction of associated semantic labels. Existing approaches fail to achieve these goals due to an incomplete treatment of a fundamental asymmetry: the intended application is always predicting labels from data, not data from labels. Our prediction-constrained objective for training generative models coherently integrates loss-based supervisory signals while enabling effective semi-supervised learning from partially labeled data. We derive learning algorithms for semi-supervised mixture and topic models using stochastic gradient descent with automatic differentiation. We demonstrate improved prediction quality compared to several previous supervised topic models, achieving predictions competitive with high-dimensional logistic regression on text sentiment analysis and electronic health records tasks while simultaneously learning interpretable topics.





Unpublished preprint, last updated July 3, 2019.

1 Introduction

Latent variable models are widely used to explain high-dimensional data by learning appropriate low-dimensional structure. For example, a model of online restaurant reviews might describe a single user’s long plain-text review as a blend of terms describing customer service and terms related to Italian cuisine. When modeling electronic health records, a single patient’s high-dimensional medical history of lab results and diagnostic reports might be described as a classic instance of juvenile diabetes. Crucially, we often wish to discover a faithful low-dimensional representation rather than rely on restrictive predefined representations. Latent variable models (LVMs), including mixture models and topic models like Latent Dirichlet Allocation (Blei et al., 2003), are widely used for unsupervised learning from high-dimensional data. There have been many efforts to generalize these methods to supervised applications in which observations are accompanied by target values, especially when we seek to predict these targets for future examples. For example, Paul and Dredze (2012) use topics from Twitter to model trends in flu, and Jiang et al. (2015) use topics from image captions to make travel recommendations. By capturing the joint distribution of input data and targets, supervised LVMs may lead to predictions that better generalize from limited training data. Unfortunately, many previous methods for the supervised learning of LVMs fail to deliver on this promise. In this work, our first contribution is to provide theoretical and empirical explanations that expose fundamental problems in these prior formulations.

One naïve application of LVMs like topic models to supervised tasks uses two-stage training: first train an unsupervised model, and then train a supervised predictor given the fixed latent representation from stage one. Unfortunately, this two-stage pipeline often fails to produce high-quality predictions, especially when the raw data features are not carefully engineered and contain structure irrelevant for prediction. For example, applying LDA to clinical records might find topics about common conditions like diabetes or heart disease, which may be irrelevant if the ultimate supervised task is predicting sleep therapy outcomes.

Because this two-stage approach is often unsatisfactory, many attempts have been made to directly incorporate supervised labels as observations in a single generative model. For mixture models, examples of supervised training are numerous (Hannah et al., 2011, Shahbaba and Neal, 2009). Similarly, many topic models have been proposed that jointly generate word counts and document labels (McAuliffe and Blei, 2007, Lacoste-Julien et al., 2009, Wang et al., 2009, Zhu et al., 2012, Chen et al., 2015). However, a survey by Halpern et al. (2012) finds that these approaches have little benefit, if any, over standard unsupervised LDA in clinical prediction tasks. Furthermore, often the quality of supervised topic models does not significantly improve as model capacity (the number of topics) increases, even when large training datasets are available.

In this work, we expose and correct several deficiencies in previous formulations of supervised topic models. We introduce a learning objective that directly enforces the intuitive goal of representing the data in a way that enables accurate downstream predictions. Our objective acknowledges the inherent asymmetry of prediction tasks: a clinician is interested in predicting sleep outcomes given medical records, but not medical records given sleep outcomes. Approaches like supervised LDA (sLDA, McAuliffe and Blei (2007)) that optimize the joint likelihood of labels and words ignore this crucial asymmetry. Our prediction-constrained latent variable models are tuned to maximize the marginal likelihood of the observed data, subject to the constraint that prediction accuracy (formalized as the conditional probability of labels given data) exceeds some target threshold.

We emphasize that our approach seeks to find a compromise between two distinct goals: build a reasonable density model of observed data while making high-quality predictions of some target values given that data. If we only cared about modeling the data well, we could simply ignore the target values and adapt standard frequentist or Bayesian training objectives. If we only cared about prediction performance, there are a host of discriminative regression and classification methods. However, we find that many applications benefit from the representations that LVMs provide, including the ability to explain target predictions from high-dimensional data via an interpretable low-dimensional representation. In many cases, introducing supervision enhances the interpretability of the generative model as well, as the task forces modeling effort to focus on only the relevant parts of high-dimensional data. Finally, in many applications it is beneficial to have the ability to learn from observed data for which target labels are unavailable. We find that especially in this semi-supervised domain, our prediction-constrained training objective provides clear wins over existing methods.

2 Prediction-constrained Training for Latent Variable Models

In this section, we develop a prediction-constrained training objective applicable to a broad family of latent variable models. Later sections provide concrete learning algorithms for supervised variants of mixture models (Everitt and Hand, 1981) and topic models (Blei, 2012). However, we emphasize that this framework could be applied much more broadly to allow supervised training of well-known generative models like probabilistic PCA (Roweis, 1998, Tipping and Bishop, 1999), dynamic topic models (Blei and Lafferty, 2006), latent feature models (Griffiths and Ghahramani, 2007), hidden Markov models for sequences (Rabiner and Juang, 1986) and trees (Crouse et al., 1998), linear dynamical system models (Shumway and Stoffer, 1982, Ghahramani and Hinton, 1996), stochastic block models for relational data (Wang and Wong, 1987, Kemp et al., 2006), and many more.

Fig. 1: Graphical models for downstream supervised LVMs amenable to prediction-constrained training. Panels: (a) general model; (b) supervised mixture (Sec. 3); (c) supervised LDA (Sec. 4).

The broad family of latent variable models we consider is illustrated in Fig. 1. We assume an observed dataset of $D$ paired observations $\{(x_d, y_d)\}_{d=1}^{D}$. We refer to $x_d$ as data and $y_d$ as labels or targets, with the understanding that in intended applications, we can easily access some new data $x_d$ but often need to predict $y_d$ from $x_d$. For example, the pairs may be text documents and their accompanying class labels, or images and accompanying scene categories, or patient medical histories and their accompanying diagnoses. We will often refer to each observation (indexed by $d$) as a document, since we are motivated in part by topic models, but we emphasize that our work is directly applicable to many other LVMs and data types.

We assume that each of the exchangeable data pairs $(x_d, y_d)$ is generated independently by the model via its own hidden variable $h_d$. For a simple mixture model, $h_d$ is an integer indicating the associated data cluster. For more complex members of our family like topic models, $h_d$ may be a set of several document-specific hidden variables. The generative process for the random variables local to document $d$ unfolds in three steps: generate $h_d$ from some prior $p(h_d \mid \pi)$, generate $x_d$ given $h_d$ according to some distribution $p(x_d \mid h_d, \phi)$, and finally generate $y_d$ given both $x_d$ and $h_d$ from some distribution $p(y_d \mid x_d, h_d, \eta)$. The joint density for document $d$ then factorizes as

$$p(x_d, y_d, h_d \mid \pi, \phi, \eta) = p(h_d \mid \pi)\, p(x_d \mid h_d, \phi)\, p(y_d \mid x_d, h_d, \eta). \tag{1}$$

We assume the generating distributions have parameterized probability density functions which can be easily evaluated and differentiated. The global parameters $\pi$, $\phi$, and $\eta$ specify each density. When training our model, we treat the global parameters $\xi = \{\pi, \phi, \eta\}$ as random variables with associated prior density $p(\xi)$.

Our chosen model family is an example of a downstream LVM: the core assumption of Eq. (1) is that the generative process produces both observed data $x_d$ and targets $y_d$ conditioned on the hidden variable $h_d$. In contrast, upstream models such as Dirichlet-multinomial regression (Mimno and McCallum, 2008), DiscLDA (Lacoste-Julien et al., 2009), and labeled LDA (Ramage et al., 2009) assume that observed labels $y_d$ are generated first, and then combined with hidden variables $h_d$ to produce data $x_d$. For upstream models, inference is challenging when labels are missing. For example, the posterior $p(h_d \mid x_d, \xi)$ in downstream models may be computed by omitting factors containing $y_d$, while upstream models must explicitly integrate over all possible $y_d$. Similarly, upstream prediction of labels $y_d$ from data $x_d$ is more complex than for downstream models. That said, our prediction-constrained framework could also be used to produce novel learning algorithms for upstream LVMs.

Given this general model family, there are two core problems of interest. The first is global parameter learning: estimating values or approximate posteriors for $\xi = \{\pi, \phi, \eta\}$ given training data $\{(x_d, y_d)\}_{d=1}^{D}$. The second is local prediction: estimating the target $y_d$ given data $x_d$ and model parameters $\xi$.

2.1 Regularized Maximum Likelihood Optimization for Training Global Parameters

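A classical approach to estimating $\xi$ would be to maximize the marginal likelihood of the training data $x$ and targets $y$, integrating over the hidden variables $h$. This is equivalent to minimizing the following objective function:

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d, y_d \mid \xi) + R(\xi), \qquad p(x_d, y_d \mid \xi) = \int p(x_d, y_d, h_d \mid \xi)\, dh_d. \tag{2}$$

Here, $R(\xi)$ denotes a (possibly uninformative) regularizer for the global parameters, and the integral becomes a sum when $h_d$ is discrete. If $R(\xi) = -\log p(\xi)$ for some prior density function $p(\xi)$, Eq. (2) is equivalent to maximum a posteriori (MAP) estimation of $\xi$.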

One problem with standard ML or MAP training is that the inputs $x$ and targets $y$ are modeled in a perfectly symmetric fashion. We could equivalently concatenate $x$ and $y$ to form one larger variable, and use standard unsupervised learning methods to find a joint representation. However, because practical models are typically misspecified and only approximate the generative process of real-world data, solving this objective can lead to solutions that are not matched to the practitioner’s goals. We care much more about predicting patient mortality rates than we do about estimating past incidences of routine checkups. Especially because inputs $x$ are usually higher-dimensional than targets $y$, conventionally trained LVMs may have poor predictive performance.

2.2 Prediction-Constrained Optimization for Training Global Parameters

As an alternative to maximizing the joint likelihood, we consider a prediction-constrained objective, where we wish to find the best possible generative model for data $x$ that meets some quality threshold for prediction of targets $y$ given $x$. A natural quality threshold for our probabilistic model is to require that the sum of log conditional probabilities $\sum_{d} \log p(y_d \mid x_d, \xi)$ must exceed some scalar value $-\epsilon$. This leads to the following constrained optimization problem:

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d \mid \xi) + R(\xi) \tag{3}$$

subject to

$$-\sum_{d=1}^{D} \log p(y_d \mid x_d, \xi) \le \epsilon.$$

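We emphasize that the conditional probability $p(y_d \mid x_d, \xi)$ marginalizes the hidden variable $h_d$:

$$p(y_d \mid x_d, \xi) = \int p(y_d \mid x_d, h_d, \eta)\, p(h_d \mid x_d, \pi, \phi)\, dh_d, \qquad p(h_d \mid x_d, \pi, \phi) \propto p(h_d \mid \pi)\, p(x_d \mid h_d, \phi). \tag{4}$$

This marginalization allows us to make predictions for $y_d$ that correctly account for our uncertainty in $h_d$ given $x_d$, and importantly, given only $x_d$. If our goal is to predict $y$ given $x$, then we cannot train our model assuming $h$ is informed by both $x$ and $y$.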

Lagrange multiplier theory tells us that any solution of the constrained problem in Eq. (3) is also a solution to the unconstrained optimization problem

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d \mid \xi) - \lambda \sum_{d=1}^{D} \log p(y_d \mid x_d, \xi) + R(\xi) \tag{5}$$

for some scalar Lagrange multiplier $\lambda > 0$. For each distinct value of $\lambda$, the solution to Eq. (5) also solves the constrained problem in Eq. (3) for a particular threshold $\epsilon$. While the mapping between $\lambda$ and $\epsilon$ is monotonic, it is not constructive and lacks a simple parametric form.

We define the optimization problem in Eq. (5) to be our prediction-constrained (PC) training objective. This objective directly encodes the asymmetric relationship between data $x_d$ and labels $y_d$ by prioritizing prediction of $y_d$ from $x_d$ when $\lambda > 1$. This contrasts with the joint maximum likelihood objective in Eq. (2), which treats these variables symmetrically and (especially when $x_d$ is high-dimensional) may not accurately model the predictive density $p(y_d \mid x_d, \xi)$. In the special case where $\lambda = 1$, the PC objective of Eq. (5) reduces to the ML objective of Eq. (2), because $\log p(x_d \mid \xi) + \log p(y_d \mid x_d, \xi) = \log p(x_d, y_d \mid \xi)$.
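To make the PC objective concrete, here is a minimal Python sketch for the special case of a discrete hidden variable, where marginalizing $h_d$ reduces to a sum over states. The function names `logsumexp` and `pc_objective`, and the input layout (per-document tables of log probabilities), are illustrative assumptions rather than any released implementation. The sketch also lets us check that $\lambda = 1$ recovers the joint negative log-likelihood of Eq. (2).

```python
import math


def logsumexp(vals):
    """Numerically stable log(sum(exp(v) for v in vals))."""
    m = max(vals)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in vals))


def pc_objective(log_p_hx, log_p_y_given_xh, lam, reg=0.0):
    """PC loss of Eq. (5) when the hidden variable takes finitely many values.

    log_p_hx[d][k]         : log p(h_d = k | pi) + log p(x_d | h_d = k, phi)
    log_p_y_given_xh[d][k] : log p(y_d | x_d, h_d = k, eta)
    lam                    : Lagrange multiplier lambda
    reg                    : precomputed value of the regularizer R(xi)
    """
    total = reg
    for lhx, lyh in zip(log_p_hx, log_p_y_given_xh):
        # log p(x_d | xi): sum the hidden variable out of its joint with x_d
        log_px = logsumexp(lhx)
        # log p(y_d | x_d, xi) = log sum_k p(h_d = k | x_d) p(y_d | x_d, h_d = k);
        # the posterior over h_d uses x_d only, never y_d
        log_py_x = logsumexp([a - log_px + b for a, b in zip(lhx, lyh)])
        total += -log_px - lam * log_py_x
    return total


# Hypothetical two-document, two-state example
lhx_demo = [[math.log(0.3), math.log(0.1)], [math.log(0.05), math.log(0.4)]]
lyh_demo = [[math.log(0.9), math.log(0.2)], [math.log(0.5), math.log(0.7)]]
print(pc_objective(lhx_demo, lyh_demo, lam=1.0))
print(pc_objective(lhx_demo, lyh_demo, lam=10.0))
```

Raising `lam` above 1 only increases the weight on the label term; the data term is unchanged, which is the asymmetry the PC objective is meant to express.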

2.2.1 Extension: Constraints on a general expected loss

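Penalizing aggregate log predictive probability is sensible for many problems, but for some applications other loss functions are more appropriate. More generally, we can penalize the expected loss between the true labels $y_d$ and predicted labels $\hat{y}_d$ under the LVM posterior $p(\hat{y}_d \mid x_d, \xi)$:

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d \mid \xi) + R(\xi) \tag{6}$$

subject to

$$\sum_{d=1}^{D} \mathbb{E}_{p(\hat{y}_d \mid x_d, \xi)}\left[\mathrm{loss}(y_d, \hat{y}_d)\right] \le \epsilon.$$
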
When measuring the quality of predictions, this more general approach allows us to incorporate classic non-probabilistic loss functions like the hinge loss or epsilon-insensitive loss, or to penalize errors asymmetrically in classification problems. However, for this paper, our algorithms and experiments focus on the probabilistic loss formulation in Eq. (5).
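For illustration only (the algorithms in this paper use the log-loss form), a minimal sketch of the expected-loss constraint of Eq. (6) for a binary label with asymmetric costs. The cost values `cost_fn` and `cost_fp` and the helper names are hypothetical.

```python
def expected_loss(p_pos, y_true, cost_fn=5.0, cost_fp=1.0):
    """E_{p(yhat | x)}[loss(y, yhat)] for a binary label with asymmetric costs.

    p_pos   : predictive probability p(yhat = 1 | x, xi) from the LVM
    y_true  : observed label, 0 or 1
    cost_fn : hypothetical cost of predicting 0 when y = 1 (missed positive)
    cost_fp : hypothetical cost of predicting 1 when y = 0 (false alarm)
    """
    if y_true == 1:
        return (1.0 - p_pos) * cost_fn  # only yhat = 0 incurs a cost
    return p_pos * cost_fp              # only yhat = 1 incurs a cost


def meets_constraint(p_pos_list, y_list, eps):
    """Check the summed expected-loss constraint against the threshold eps."""
    total = sum(expected_loss(p, y) for p, y in zip(p_pos_list, y_list))
    return total <= eps


print(meets_constraint([0.8, 0.1], [1, 0], eps=1.2))
```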

2.2.2 Extension: Prediction constraints for individual data items

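In Eq. (3), we defined our prediction quality constraint using the sum (or equivalently, the average) of the document-specific losses $-\log p(y_d \mid x_d, \xi)$. An alternative, more stringent training objective would enforce separate prediction constraints for each document:

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d \mid \xi) + R(\xi) \tag{7}$$

subject to

$$-\log p(y_d \mid x_d, \xi) \le \epsilon_d \quad \text{for all } d \in \{1, \ldots, D\}.$$
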
This modified optimization problem would generalize Eq. (5) by allocating a distinct Lagrange multiplier weight $\lambda_d$ for each observation $d$. Tuning these weights would require more sophisticated optimization algorithms, a topic we leave for future research.

2.2.3 Extension: Semi-supervised prediction constraints for data with missing labels

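In many applications, we have a dataset of $D$ observations for which only a subset $\mathcal{D}^{\mathcal{L}} \subseteq \{1, \ldots, D\}$ have observed labels $y_d$; the remaining labels are unobserved. For semi-supervised learning problems like this, we generalize Eq. (3) to only enforce the label prediction constraint for the documents in $\mathcal{D}^{\mathcal{L}}$, so that the PC objective of Eq. (3) becomes:

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d \mid \xi) + R(\xi) \tag{8}$$

subject to

$$-\sum_{d \in \mathcal{D}^{\mathcal{L}}} \log p(y_d \mid x_d, \xi) \le \epsilon.$$
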
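In general, the value of $\epsilon$ will need to be adapted based on the amount of labeled data. In the unconstrained form

$$\min_{\xi} \; -\sum_{d=1}^{D} \log p(x_d \mid \xi) - \lambda \sum_{d \in \mathcal{D}^{\mathcal{L}}} \log p(y_d \mid x_d, \xi) + R(\xi), \tag{9}$$
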
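as the fraction of labeled data $|\mathcal{D}^{\mathcal{L}}| / D$ gets smaller, we will need a much larger Lagrange multiplier $\lambda$ to uphold the same average quality in predictive performance. This occurs simply because as $|\mathcal{D}^{\mathcal{L}}| / D$ gets smaller, the data likelihood term $-\sum_{d=1}^{D} \log p(x_d \mid \xi)$ will continue to get larger in relative magnitude compared to the label prediction term $-\lambda \sum_{d \in \mathcal{D}^{\mathcal{L}}} \log p(y_d \mid x_d, \xi)$.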
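A minimal sketch of the unconstrained semi-supervised objective, plus a toy calculation of the effect just described. The helper `label_share` assumes, hypothetically, that every document contributes the same data loss and the same label loss; under that simplification, cutting the labeled fraction by 10× requires a 10× larger $\lambda$ to keep the label term's share of the total loss unchanged.

```python
def semi_supervised_pc(log_px, log_py_x, lam):
    """Unconstrained semi-supervised PC loss (regularizer omitted).

    log_px[d]   : log p(x_d | xi) for every document d
    log_py_x[d] : log p(y_d | x_d, xi), or None if y_d is unobserved
    """
    data_term = -sum(log_px)                                 # all D documents
    label_term = -sum(v for v in log_py_x if v is not None)  # labeled subset only
    return data_term + lam * label_term


def label_share(n_docs, n_labeled, avg_data_nll, avg_label_nll, lam):
    """Fraction of the total loss due to the label term, under the hypothetical
    simplification that every document has the same per-document losses."""
    data = n_docs * avg_data_nll
    label = lam * n_labeled * avg_label_nll
    return label / (data + label)


print(label_share(1000, 1000, 100.0, 0.5, 10.0))  # all documents labeled
print(label_share(1000, 100, 100.0, 0.5, 10.0))   # 10% labeled, same lambda
print(label_share(1000, 100, 100.0, 0.5, 100.0))  # 10% labeled, 10x lambda
```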

2.3 Relationship to Other Supervised Learning Frameworks

While the definition of the PC training objective in Eq. (5) is straightforward, it has desirable features that are not shared by other supervised training objectives for downstream LVMs. In this section we contrast the PC objective with several other approaches, often comparing to methods from the topic modeling literature to give concrete alternatives.

2.3.1 Advantages over standard joint likelihood training

For our chosen family of supervised downstream LVMs, the most standard training method is to find a point estimate of global parameters $\xi$ that maximizes the (regularized) joint log-likelihood $\log p(x, y \mid \xi)$ as in Eq. (2). Related Bayesian methods that approximate the posterior distribution $p(\xi \mid x, y)$, such as variational methods (Wainwright and Jordan, 2008) and Markov chain Monte Carlo methods (Andrieu et al., 2003), estimate moments of the same joint likelihood (see Eq. (1)) relating hidden variables $h$ to data $x$ and labels $y$.

For example, supervised LDA (McAuliffe and Blei, 2007, Wang et al., 2009) learns latent topic assignments $h$ by optimizing the joint probability of bag-of-words document representations $x$ and document labels $y$. One of several problems with this joint likelihood objective is cardinality mismatch: the relative sizes of the random variables $x$ and $y$ can reduce predictive performance. In particular, if $y$ is a one-dimensional binary label but $x$ is a high-dimensional word count vector, the optimal solution to Eq. (2) will often be indistinguishable from the solution to the unsupervised problem of modeling the data $x$ alone. Low-dimensional labels $y$ can have negligible impact on the joint density compared to the high-dimensional words $x$, causing learning to ignore subtle features that are critical for the prediction of $y$ from $x$. Despite this issue, recent work continues to use this training objective (Wang and Zhu, 2014, Ren et al., 2017).
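A back-of-the-envelope calculation shows the scale of the cardinality mismatch. This sketch assumes, purely for illustration, a 200-token document over a 5000-word vocabulary and a binary label, each predicted uniformly at random; a trained model assigns higher probabilities, so the point is only the relative magnitude of the two log-likelihood terms. The function name is hypothetical.

```python
import math


def uniform_baseline_nll(n_words, vocab_size, n_classes=2):
    """Negative log-likelihoods of one document's word tokens and of its label
    when every token and the label are (hypothetically) predicted uniformly."""
    data_nll = n_words * math.log(vocab_size)  # -log p(x_d) for n_words tokens
    label_nll = math.log(n_classes)            # -log p(y_d) for one label
    return data_nll, label_nll


data_nll, label_nll = uniform_baseline_nll(n_words=200, vocab_size=5000)
print(round(data_nll, 1), round(label_nll, 3), round(data_nll / label_nll))
```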

2.3.2 Advantages over maximum conditional likelihood training

Motivated by similar concerns about joint likelihood training, Jebara and Pentland (1999) introduce a method to explicitly optimize the conditional likelihood for a particular LVM, the Gaussian mixture model. They replace the conditional likelihood with a more tractable lower bound, and then monotonically increase this bound via a coordinate ascent algorithm they call conditional expectation maximization (CEM). Chen et al. (2015) instead use a variant of backpropagation to optimize the conditional likelihood of a supervised topic model.

One concern about the conditional likelihood objective is that it exclusively focuses on the prediction task; it need not lead to good models of the data $x$, and cannot incorporate unlabeled data. In contrast, our prediction-constrained approach allows a principled tradeoff between optimizing the marginal likelihood of data and the conditional likelihood of labels given data.

2.3.3 Advantages over label replication

We are not the first to notice that high-dimensional data $x$ can swamp the influence of low-dimensional labels $y$. Among practitioners, one common workaround to this imbalance is to retain the symmetric maximum likelihood objective of Eq. (2), but to replicate each label $y_d$ as if it were observed $r$ times per document: $\{y_d, y_d, \ldots, y_d\}$. Applied to supervised LDA, label replication leads to an alternative power sLDA topic model (Zhang and Kjellström, 2014).

Label replication still leads to nearly the same per-document joint density as in Eq. (1), except that the likelihood density $p(y_d \mid x_d, h_d, \eta)$ is raised to the $r$-th power: $p(y_d \mid x_d, h_d, \eta)^r$. While label replication can better “balance” the relative sizes of $x$ and $y$ when $r \gg 1$, performance gains over standard supervised LDA are often negligible (Zhang and Kjellström, 2014), because this approach does not address the asymmetry issue. To see why, we examine the label-replicated training objective:

$$\min_{\xi} \; -\sum_{d=1}^{D} \log \int p(h_d \mid \pi)\, p(x_d \mid h_d, \phi)\, p(y_d \mid x_d, h_d, \eta)^r \, dh_d + R(\xi). \tag{10}$$

This objective does not contain any direct penalty on the predictive density $p(y_d \mid x_d, \xi)$, which is the fundamental idea of our prediction-constrained approach and a core term in the objective of Eq. (5). Instead, only the symmetric joint density of $x_d$ and $y_d$ is maximized, with training assuming both data $x_d$ and replicated labels $y_d$ are present. It is easy to find examples where the optimal solution to this objective performs poorly on the target task of predicting $y$ given only $x$, because the training has not directly prioritized this asymmetric prediction. In later sections such as the case study in Fig. 2, we provide intuition-building examples where maximum likelihood joint training with label replication fails to give good prediction performance for any value of the replication weight, while our PC approach can do better when $\lambda$ is sufficiently large.
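The following sketch evaluates the label-replicated objective of Eq. (10) for a discrete hidden variable, with the regularizer omitted; the function names are our own. Note that it never forms $p(y_d \mid x_d, \xi)$. The second usage line illustrates one consequence: if $p(y \mid h)$ is already exactly 0 or 1, raising it to any power $r$ leaves the objective unchanged, so further replication exerts no additional pressure on how $h$ is inferred from $x$ alone.

```python
import math


def logsumexp(vals):
    """Numerically stable log(sum(exp(v) for v in vals))."""
    m = max(vals)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in vals))


def replicated_objective(log_p_hx, log_p_y_given_xh, r):
    """Label-replicated loss for discrete h (regularizer omitted):
    -sum_d log sum_k p(h_d = k) p(x_d | h_d = k) p(y_d | x_d, h_d = k)^r.
    There is no separate term for the predictive density log p(y_d | x_d)."""
    return -sum(
        logsumexp([a + r * b for a, b in zip(lhx, lyh)])
        for lhx, lyh in zip(log_p_hx, log_p_y_given_xh)
    )


# Hypothetical one-document example with two hidden states
lhx = [[math.log(0.3), math.log(0.1)]]
lyh_soft = [[math.log(0.9), math.log(0.2)]]
lyh_hard = [[0.0, -math.inf]]  # p(y | h) is exactly 1 or 0
print(replicated_objective(lhx, lyh_soft, 1), replicated_objective(lhx, lyh_soft, 50))
print(replicated_objective(lhx, lyh_hard, 1), replicated_objective(lhx, lyh_hard, 50))
```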

Example: Label replication may lead to poor predictions.

Even when the number of replicated labels $r \to \infty$, the optimal solution to the label-replicated training objective of Eq. (10) may be suboptimal for the prediction of $y$ given $x$. To demonstrate this, we consider a toy example involving two-component Gaussian mixture models.

Consider a one-dimensional data set consisting of six evenly spaced points, . The three points where have positive labels $y = 1$, while the rest have negative labels $y = 0$. Suppose our goal is to fit a mixture model with two Gaussian components to these data, assuming minimal regularization (that is, sufficient only to prevent the probabilities of clusters and targets from being exactly 0 or 1). Let $h \in \{0, 1\}$ indicate the (hidden) mixture component for $x$.

If $r \to \infty$, the label factor $p(y \mid h)^r$ will dominate in Eq. (10). This term can be optimized by setting $h = y$, and setting the probability of $y$ close to 0 or 1 depending on the cluster. In particular, we choose and . If one computes the maximum likelihood solution to the remaining parameters given these assignments of $h$, the resulting labels-from-data likelihood equals , and two points are misclassified. Misclassification occurs because the two clusters have significant overlap.

However, there exists an alternative two-component mixture model that yields better labels-given-data likelihood and makes fewer mistakes. We set the cluster centers to and , and the cluster variances to and . Under this model, we get a labels-given-data likelihood of , and only one point is misclassified. This solution achieves a lower misclassification rate by choosing one narrow Gaussian cluster to model the adjacent positive points correctly, while making no attempt to capture the positive point at . Therefore, the solution to Eq. (10) is suboptimal for making predictions about $y$ given $x$.

This counter-example also illustrates the intuition behind why the replicated objective fails: increasing the replicates of $y$ forces $h$ to take on a value that is predictive of $y$ during training, that is, to get $p(y \mid h)$ as close to 1 as possible. However, there are no guarantees on $p(h \mid x)$, which is necessary for predicting $y$ given $x$. See Fig. 2 for an additional in-depth example.
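A minimal sketch of the large-$r$ recipe described in the example: tie each cluster to a label ($h = y$), fit each Gaussian by maximum likelihood on its points, then score labels given data through $p(h \mid x)$. The data below are hypothetical and chosen only to exercise the code; they are not the six-point data set of the example. We also assume $p(y = k \mid h = k) = 1$ exactly, ignoring the minimal regularization mentioned above. All function names are our own.

```python
import math


def normal_logpdf(x, mu, var):
    """Log density of a 1-D Gaussian with mean mu and variance var."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mu) ** 2 / var)


def fit_clusters_from_labels(xs, ys):
    """Tie cluster h to label y (h = y), then fit each cluster's weight, mean,
    and variance by maximum likelihood on the points carrying that label."""
    params = []
    for k in (0, 1):
        pts = [x for x, y in zip(xs, ys) if y == k]
        mu = sum(pts) / len(pts)
        var = sum((x - mu) ** 2 for x in pts) / len(pts)
        params.append((len(pts) / len(xs), mu, var))
    return params


def score_labels_given_data(xs, ys, params):
    """Return (sum_d log p(y_d | x_d), number misclassified), assuming
    p(y = k | h = k) = 1 so that p(y | x) equals the posterior p(h = y | x)."""
    total, errors = 0.0, 0
    for x, y in zip(xs, ys):
        logs = [math.log(w) + normal_logpdf(x, mu, var) for w, mu, var in params]
        m = max(logs)
        e0, e1 = math.exp(logs[0] - m), math.exp(logs[1] - m)
        post1 = e1 / (e0 + e1)  # p(h = 1 | x)
        total += math.log(post1 if y == 1 else 1.0 - post1)
        errors += int((post1 > 0.5) != (y == 1))
    return total, errors


# Hypothetical interleaved labels: the per-label clusters overlap heavily.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [0, 1, 0, 1, 0, 1]
print(score_labels_given_data(xs, ys, fit_clusters_from_labels(xs, ys)))
```

Tying clusters to labels makes $p(y \mid h)$ perfect on the training data, yet overlapping clusters still produce errors once $h$ must be inferred from $x$ alone.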

2.3.4 Advantages over posterior regularization

The posterior regularization (PR) framework introduced by Graça et al. (2008), and later refined in Ganchev et al. (2010), is notable early work which applied explicit performance constraints to latent variable model objective functions. Most of this work focused on models for only two local random variables: data $x$ and hidden variables $h$, without any explicit labels $y$. Mindful of this, we can naturally express the PR objective in our notation, explaining data $x$ explicitly via an objective function and incorporating labels $y$ only later in the performance constraints.

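The PR approach begins with the same overall goals of the expectation-maximization treatment of maximum likelihood inference: frame the problem as estimating an approximate posterior $q(h_d \mid \hat{h}_d)$ for each latent variable set $h_d$, such that this approximation is as close as possible in KL divergence to the real (perhaps intractable) posterior $p(h_d \mid x_d, \xi)$. Generally, we select the density $q$ to be from a tractable parametric family with free parameters $\hat{h}_d$ restricted to some parameter space $\hat{\mathcal{H}}$ which makes $q$ a valid density. This leads to the objective

$$\min_{\xi, \{\hat{h}_d\}} \; -\mathcal{L}(\xi, \hat{h}) + R(\xi), \qquad \mathcal{L}(\xi, \hat{h}) = \sum_{d=1}^{D} \mathbb{E}_{q(h_d \mid \hat{h}_d)}\left[\log p(x_d, h_d \mid \xi) - \log q(h_d \mid \hat{h}_d)\right]. \tag{11}$$
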
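Here, the function $\mathcal{L}(\xi, \hat{h})$ is a lower bound on the data log likelihood $\sum_{d} \log p(x_d \mid \xi)$, with equality when each $q(h_d \mid \hat{h}_d)$ matches the true posterior $p(h_d \mid x_d, \xi)$. The popular EM algorithm optimizes this objective via coordinate descent steps that alternately update variational parameters $\hat{h}$ and model parameters $\xi$. The PR framework of Graça et al. (2008) adds additional constraints to the approximate posterior $q(h)$ so that some additional loss function of interest, over both observed and latent variables, has bounded value under the distribution $q(h)$:

$$\min_{\xi, \{\hat{h}_d\}} \; -\mathcal{L}(\xi, \hat{h}) + R(\xi) \tag{12}$$

$$\text{subject to} \quad \sum_{d=1}^{D} \mathbb{E}_{q(h_d \mid \hat{h}_d)}\left[\mathrm{loss}(y_d, h_d, x_d)\right] \le \epsilon. \tag{13}$$
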
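For our purposes, one possible loss function could be the negative log likelihood for the label $y$: $\mathrm{loss}(y_d, h_d, x_d) = -\log p(y_d \mid x_d, h_d, \eta)$. It is informative to directly compare the PR constraint above with the PC objective of Eq. (6). Our approach directly constrains the expected loss under the true hidden-variable-from-data posterior $p(h_d \mid x_d, \xi)$:

$$-\sum_{d=1}^{D} \log \mathbb{E}_{p(h_d \mid x_d, \xi)}\left[p(y_d \mid x_d, h_d, \eta)\right] \le \epsilon.$$
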
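In contrast, the PR approach in Eq. (13) constrains the expectation under the approximate posterior $q(h_d \mid \hat{h}_d)$. This posterior does not have to stay close to the true hidden-variable-from-data posterior $p(h_d \mid x_d, \xi)$. Indeed, when we write the PR objective in unconstrained form with Lagrange multiplier $\lambda$, and assume the loss is the negative label log-likelihood, we have:

$$\min_{\xi, \{\hat{h}_d\}} \; -\sum_{d=1}^{D} \mathbb{E}_{q(h_d \mid \hat{h}_d)}\left[\log p(h_d \mid \pi) + \log p(x_d \mid h_d, \phi) + \lambda \log p(y_d \mid x_d, h_d, \eta) - \log q(h_d \mid \hat{h}_d)\right] + R(\xi). \tag{14}$$
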
Shown this way, we reach a surprising conclusion: the PR objective reduces to a lower bound on the symmetric joint likelihood with labels replicated $\lambda$ times. Thus, it will inherit all the problems of label replication discussed above, as the optimal training update for $q(h_d \mid \hat{h}_d)$ incorporates information from both data $x_d$ and labels $y_d$. However, this does not train the model to find a good approximation of $p(h_d \mid x_d, \xi)$, which we will show is critical for good predictive performance.
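This conclusion can be checked numerically for a single document with a discrete hidden variable. The minimal sketch below (function names are our own) uses the standard fact that the unconstrained PR bound is maximized by $q(h) \propto p(h, x)\, p(y \mid x, h)^{\lambda}$, at which point it equals the log of the label-replicated marginal. That optimal $q$ visibly depends on $y$, unlike the data-only posterior $p(h \mid x)$ recovered at $\lambda = 0$.

```python
import math


def pr_bound(q, log_p_hx, log_p_y_given_xh, lam):
    """E_q[log p(h, x) + lam * log p(y | x, h) - log q(h)] for one document."""
    return sum(qk * (a + lam * b - math.log(qk))
               for qk, a, b in zip(q, log_p_hx, log_p_y_given_xh) if qk > 0)


def optimal_q(log_p_hx, log_p_y_given_xh, lam):
    """Maximizer of pr_bound over q: q(h) proportional to p(h, x) p(y | x, h)^lam."""
    logs = [a + lam * b for a, b in zip(log_p_hx, log_p_y_given_xh)]
    m = max(logs)
    w = [math.exp(v - m) for v in logs]
    s = sum(w)
    return [v / s for v in w]


# Hypothetical document with two hidden states
log_p_hx = [math.log(0.3), math.log(0.1)]         # log p(h = k, x)
log_p_y_given_xh = [math.log(0.2), math.log(0.9)]  # log p(y | x, h = k)
print(optimal_q(log_p_hx, log_p_y_given_xh, 0.0))  # lam = 0: data-only posterior p(h | x)
print(optimal_q(log_p_hx, log_p_y_given_xh, 5.0))  # lam = 5: pulled toward states that explain y
```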

2.3.5 Advantages over maximum entropy discrimination and regularized Bayes

Another key thread of related work putting constraints on approximate posteriors is known as maximum entropy discrimination (MED), first published in Jaakkola et al. (1999b) with further details in follow-up work (Jaakkola et al., 1999a, Jebara, 2001). This approach was developed for training discriminative models without hidden variables, where the primary innovation was showing how to manage uncertainty about parameter estimation under max-margin-like objectives. In the context of LVMs, this MED work differs from standard EM optimization in two important and separable ways. First, it estimates a posterior for global parameters instead of a simple point estimate. Second, it enforces a margin constraint on label prediction, rather than just maximizing log probability of labels. We note briefly that Jaakkola et al. (1999a) did consider a MED objective for unsupervised latent variable models (see their Eq. 48), where the constraint is directly on the expectation of the lower bound of the log data likelihood. The choice to constrain the data likelihood is fundamentally different from constraining the labels-given-data loss, which the original MED work did not do for LVMs, yet which is more aligned with our focus on high-quality predictions.

The key application of MED to supervised LVMs has been Zhu et al. (2012)’s MED-LDA, an extension of the LDA topic model based on a MED-inspired training objective. Later work developed similar objectives for other LVMs under the broad name of regularized Bayesian inference (Zhu et al., 2014). To understand these objectives, we focus on Zhu et al. (2012)’s original unconstrained training objectives for MED-LDA for both regression (Problem 2, Eq. 8 on p. 2246) and classification (Problem 3, Eq. 19 on p. 2252), which can be fit into our notation as follows:¹

$$\min_{q(\xi), \{\hat{h}_d\}} \; \mathrm{KL}\big(q(\xi) \,\|\, p_0(\xi)\big) - \mathbb{E}_{q(\xi)}\left[\mathcal{L}(\xi, \hat{h})\right] + C \sum_{d=1}^{D} \mathrm{loss}\Big(y_d, \mathbb{E}_{q(\xi)\, q(h_d \mid \hat{h}_d)}\big[\hat{y}_d(x_d, h_d, \xi)\big]\Big) \tag{15}$$

¹ We note an irregularity between the classification and regression formulations of MED-LDA published by Zhu et al. (2012): while classification-MED-LDA included labels only in the loss term, regression-MED-LDA included two terms in the objective that penalize reconstruction of $y$: one inside the likelihood bound term using a Gaussian likelihood, and one inside a separate epsilon-insensitive loss term. Here, we assume for simplicity that only the loss term is used.
Here $C > 0$ is a scalar emphasizing how important the loss function is relative to the unsupervised problem, $p_0(\xi)$ is some prior distribution on global parameters, $\hat{y}_d(x_d, h_d, \xi)$ denotes the model’s prediction of $y_d$, and $\mathcal{L}(\xi, \hat{h})$ is the same lower bound as in Eq. (11). We can make this objective more comparable to our earlier objectives by performing point estimation of $\xi$ instead of posterior approximation, which is reasonable in moderate to large data regimes, as the posterior for the global parameters will concentrate. This choice allows us to focus on our core question of how to define an objective that balances data $x$ and labels $y$, rather than the separate question of managing uncertainty during this training. When we make this simplification by substituting point estimates for expectations, the KL divergence regularization term reduces to $R(\xi) = -\log p_0(\xi)$, and the MED-LDA objective becomes:

$$\min_{\xi, \{\hat{h}_d\}} \; -\mathcal{L}(\xi, \hat{h}) + R(\xi) + C \sum_{d=1}^{D} \mathrm{loss}\Big(y_d, \mathbb{E}_{q(h_d \mid \hat{h}_d)}\big[\hat{y}_d(x_d, h_d, \xi)\big]\Big). \tag{16}$$

Both this objective and Graça et al. (2008)’s PR framework consider expectations over the approximate posterior $q(h_d \mid \hat{h}_d)$, rather than our choice of the data-only posterior $p(h_d \mid x_d, \xi)$. However, the key difference between MED-LDA and the PR objectives is that the MED-LDA objective computes the loss of an expected prediction ($\mathrm{loss}(y_d, \mathbb{E}_q[\hat{y}_d])$), while the earlier PR objective in Eq. (13) penalizes the full expectation of the loss ($\mathbb{E}_q[\mathrm{loss}(y_d, \hat{y}_d)]$). Earlier MED work (Jaakkola et al., 1999a) also suggests using an expectation of the loss, $\mathbb{E}_q[\mathrm{loss}(y_d, \hat{y}_d)]$. Decision theory argues that the latter choice is preferable when possible, since it should lead to decisions that better minimize loss under uncertainty. We suspect that MED-LDA chooses the former only because it leads to more tractable algorithms for their chosen loss functions.

Motivated by this decision-theoretic view, we consider modifying the MED-LDA objective of Eq. (16) so that we take the full expectation of the loss. This swap can also be justified by assuming the loss function is convex, as are both the epsilon-insensitive loss and the hinge loss used by MED-LDA, so that Jensen’s inequality may be used to bound the objective in Eq. (16) from above. The resulting training objective is:

$$\min_{\xi, \{\hat{h}_d\}} \; -\mathcal{L}(\xi, \hat{h}) + R(\xi) + C \sum_{d=1}^{D} \mathbb{E}_{q(h_d \mid \hat{h}_d)}\Big[\mathrm{loss}\big(y_d, \hat{y}_d(x_d, h_d, \xi)\big)\Big]. \tag{17}$$

In this form, treating the loss as the negative label log-likelihood $-\log p(y_d \mid x_d, h_d, \eta)$, we see that we have recovered a lower bound on the symmetric maximum likelihood objective with label replication from Eq. (10), with $y$ replicated $C$ times. Thus, even this MED effort fails to properly handle the asymmetry issue we have raised, possibly leading to poor generalization performance.
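The Jensen step that turns the loss of an expected prediction into an expected loss can be seen on a two-state example. This minimal sketch assumes a hinge loss with labels in $\{-1, +1\}$ and hypothetical real-valued prediction scores, one per hidden state, weighted by $q(h_d)$; the function names are our own.

```python
def hinge(y, score):
    """Hinge loss for a label y in {-1, +1} and a real-valued score."""
    return max(0.0, 1.0 - y * score)


def loss_of_expected_score(y, scores, q):
    """loss(y, E_q[score]): penalize the averaged prediction (MED-LDA style)."""
    return hinge(y, sum(qk * s for qk, s in zip(q, scores)))


def expected_loss(y, scores, q):
    """E_q[loss(y, score)]: average the loss itself over hidden states."""
    return sum(qk * hinge(y, s) for qk, s in zip(q, scores))


# Hypothetical: two equally weighted hidden states with opposite-signed scores
print(loss_of_expected_score(1, [2.0, -1.0], [0.5, 0.5]), expected_loss(1, [2.0, -1.0], [0.5, 0.5]))
```

Because the hinge loss is convex, the second quantity is never smaller than the first; averaging the scores first hides the state whose prediction is badly wrong.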

2.4 Relationship to Semi-supervised Learning Frameworks

Often, semi-supervised training is performed via optimization of the joint likelihood $p(x, y \mid \xi)$, using the EM algorithm to impute missing data (Nigam et al., 1998). Other work falls under the thread of “self-training”, where a model trained on labeled data only is used to label additional data and then retrained accordingly. Chang et al. (2007) incorporated constraints into semi-supervised self-training of an upstream hidden Markov model (HMM). Starting with just a small labeled dataset, they iterate between two steps: (1) train model parameters via maximum likelihood estimation on the fully labeled set, and (2) expand and revise the fully labeled set via a constraint-driven approach. Given several candidate labelings for some example, their step 2 reranks these to prefer those that obey some soft constraints (for example, in a bibliographic labeling task, they require the “title” field to always appear once). Importantly, however, this work’s subprocedure for training from fully labeled data is a symmetric maximum likelihood objective, while our PC approach more directly encodes the asymmetric structure of prediction tasks.

Other work deliberately specifies prior domain knowledge about label distributions, and penalizes models that deviate from this prior when predicting on unlabeled data. Mann and McCallum (2010) propose generalized expectation (GE) constraints, which extend their earlier expectation regularization (XR) approach (Mann and McCallum, 2007). This objective has two terms: a conditional likelihood objective, and a new regularization term comparing model predictions to some weak domain knowledge:

$$\max_{\xi} \; \sum_{d \in \mathcal{D}^{\mathcal{L}}} \log p(y_d \mid x_d, \xi) - \lambda\, \Delta\big(\tilde{p}, \hat{p}_{\xi}\big) \tag{18}$$

Here, indicates some expected domain knowledge about the overall labels-given-data distribution, while is the predicted labels-given-data distribution under the current model. The distance function , weighted by , penalizes predictions that deviate from the domain knowledge. Unlike our PC approach, this objective focuses exclusively on the label prediction task and does not at all incorporate the notion of generative modeling.

3 Case Study: Prediction-constrained Mixture Models

We now present a simple case study applying prediction-constrained training to supervised mixture models. Our goal is to illustrate the benefits of our prediction-constrained approach in a situation where the marginalization over the hidden variables in Eq. (5) can be computed exactly in closed form. This allows direct comparison of our proposed PC training objective to alternatives like maximum likelihood, without worrying about how the approximations needed for tractable inference affect either objective.

Consider a simple supervised mixture model which generates data pairs $(x_n, y_n)$, as illustrated in Fig. 1(b). This mixture model assumes there are $K$ possible discrete hidden states, and that the only hidden variable at each data point is an indicator variable $z_n \in \{1, \ldots, K\}$, which indicates which of the $K$ clusters point $n$ is assigned to. For the mixture model, we parameterize the densities in Eq. (1) as follows:

$$p(z_n \mid \pi) = \mathrm{Cat}(z_n \mid \pi), \qquad p(x_n \mid z_n = k, \phi) = p(x_n \mid \phi_k), \qquad p(y_n \mid x_n, z_n = k, \eta) = p(y_n \mid \eta_k).$$

The parameter set of the latent variable prior is simple: $\pi$, a vector of $K$ positive numbers that sum to one, representing the prior probability of each cluster.

We emphasize that the data likelihood $p(x_n \mid \phi_k)$ and label likelihood $p(y_n \mid \eta_k)$ are left in generic form because they are relatively modular: one could apply the mixture model objectives below with many different data and label distributions, so long as they have valid densities that are easy to evaluate and optimize with respect to their parameters. Fig. 1(b) happens to show the particular likelihood choices we used in our toy data experiments (a Gaussian distribution for $x$ and a Bernoulli distribution for $y$), but we develop PC training for the general case. The only assumption we make is that each of the $K$ clusters has a separate parameter set: $\phi = \{\phi_k\}_{k=1}^{K}$ and $\eta = \{\eta_k\}_{k=1}^{K}$.

Related work on supervised mixtures.

While to our knowledge our prediction-constrained optimization objective is novel, there is a large related literature on applying mixtures to supervised problems where the practitioner observes pairs of data covariates $x_n$ and targets $y_n$. One line of work uses generative models with factorization structure like Fig. 1, where each cluster $k$ has parameters $\phi_k$ for generating data and $\eta_k$ for generating targets. For example, Ghahramani and Jordan (1993, Sec. 4.2) consider nearly the same model as in our toy experiments (except for using a categorical distribution over labels instead of a Bernoulli). They derive an Expectation Maximization (EM) algorithm to maximize a lower bound on the symmetric joint log likelihood $\log p(x, y)$. Later applied work has sometimes called such models Bayesian profile regression when the targets are real-valued (Molitor et al., 2010). These efforts have been broadly extended to generalized linear models, especially in the context of Bayesian nonparametric priors like the Dirichlet process fit with MCMC sampling procedures (Shahbaba and Neal, 2009, Hannah et al., 2011, Liverani et al., 2015). However, none of these efforts correct for the asymmetry issues we have raised; they simply use the symmetric joint likelihood.

Other work takes a more discriminative view of the clustering task. Krause et al. (2010) develop an objective called Regularized Information Maximization, which learns a conditional distribution over cluster labels that preserves information from the data $x$. Other efforts, such as “supervised clustering” (Eick et al., 2004), do not estimate probability densities at all. Many applications of this paradigm exist (Finley and Joachims, 2005, Al-Harbi and Rayward-Smith, 2006, DiCicco and Patel, 2010, Peralta et al., 2013, Ramani and Jacob, 2013, Grbovic et al., 2013, Peralta et al., 2016, Flammarion et al., 2016, Ismaili et al., 2016, Yoon et al., 2016, Dhurandhar et al., 2017).

3.1 Objective function evaluation and parameter estimation

Computing the data log likelihood.

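The marginal likelihood of a single data example $x_n$, marginalizing over the latent variable $z_n$, can be computed in closed form via the logsumexp function:

$$\log p(x_n \mid \pi, \phi) = \log \sum_{k=1}^{K} \pi_k\, p(x_n \mid \phi_k) = \operatorname{logsumexp}_{k} \big[ \log \pi_k + \log p(x_n \mid \phi_k) \big].$$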

Computing the label given data log likelihood.

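Similarly, the likelihood of labels given data, marginalizing away the latent variable $z_n$, can be computed in closed form:

$$\log p(y_n \mid x_n, \pi, \phi, \eta) = \operatorname{logsumexp}_{k} \big[ \log \pi_k + \log p(x_n \mid \phi_k) + \log p(y_n \mid \eta_k) \big] - \log p(x_n \mid \pi, \phi).$$

Equivalently, $p(y_n \mid x_n) = \sum_k p(z_n = k \mid x_n)\, p(y_n \mid \eta_k)$, averaging each cluster's label likelihood under the posterior responsibilities $p(z_n = k \mid x_n) \propto \pi_k\, p(x_n \mid \phi_k)$.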
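As a minimal sketch of both closed-form quantities, assuming the 1D Gaussian data likelihood and Bernoulli label likelihood from our toy experiments (the function and argument names here are hypothetical), each reduces to a log-sum-exp over clusters, with $\log p(y_n \mid x_n) = \log p(x_n, y_n) - \log p(x_n)$:

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def mixture_log_probs(x, y, weights, mu, sigma, rho):
    """Return (log p(x_n), log p(y_n | x_n)) for every example n.

    x: (N,) real-valued data; y: (N,) binary labels in {0, 1}
    weights: (K,) cluster probabilities; mu, sigma: (K,) Gaussian parameters
    rho: (K,) per-cluster probability of emitting a positive label
    """
    x = np.asarray(x, dtype=float)[:, None]  # (N, 1) so it broadcasts over K
    y = np.asarray(y, dtype=float)[:, None]
    # log pi_k + log N(x_n | mu_k, sigma_k^2), shape (N, K)
    log_px_k = (np.log(weights)
                - 0.5 * np.log(2.0 * np.pi * sigma ** 2)
                - 0.5 * ((x - mu) / sigma) ** 2)
    # log Bern(y_n | rho_k), shape (N, K)
    log_py_k = y * np.log(rho) + (1.0 - y) * np.log1p(-rho)
    log_px = logsumexp(log_px_k, axis=1)               # log p(x_n)
    log_pxy = logsumexp(log_px_k + log_py_k, axis=1)   # log p(x_n, y_n)
    return log_px, log_pxy - log_px                    # log p(y_n | x_n)
```

Because both terms are smooth functions of the parameters, the same code can serve as the per-example building block for gradient-based training.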

PC parameter estimation via gradient descent.

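Our original unconstrained PC optimization problem in Eq. (5) can thus be formulated for mixture models using these closed-form marginal probability functions and appropriate regularization terms $R(\pi, \phi, \eta)$:

$$\min_{\pi, \phi, \eta} \; -\sum_{n=1}^{N} \log p(x_n \mid \pi, \phi) \;-\; \lambda \sum_{n \in \mathcal{L}} \log p(y_n \mid x_n, \pi, \phi, \eta) \;+\; R(\pi, \phi, \eta),$$

where $\mathcal{L}$ indexes the examples with observed labels (all $n$ in the fully supervised case).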
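For concreteness, here is a minimal sketch of this loss, assuming the per-example terms $\log p(x_n)$ and $\log p(y_n \mid x_n)$ have already been computed (for example, with the closed-form expressions above). The function name, the boolean mask `is_labeled`, and the scalar `reg` standing in for $R$ are hypothetical.

```python
import numpy as np


def pc_objective(log_px, log_py_given_x, is_labeled, lam, reg=0.0):
    """Prediction-constrained loss to minimize.

    log_px:         (N,) values of log p(x_n)
    log_py_given_x: (N,) values of log p(y_n | x_n); ignored where unlabeled
    is_labeled:     (N,) boolean mask of examples with observed labels
    lam:            Lagrange multiplier weight on the label term
    reg:            scalar regularization penalty R(pi, phi, eta)
    """
    log_px = np.asarray(log_px, dtype=float)
    # Unlabeled examples contribute only to the data term.
    log_py = np.where(is_labeled, log_py_given_x, 0.0)
    return -np.sum(log_px) - lam * np.sum(log_py) + reg
```

With `lam=1` and every example labeled, the value equals the negative joint log likelihood plus `reg`, since $\log p(x_n) + \log p(y_n \mid x_n) = \log p(x_n, y_n)$; larger `lam` upweights only the prediction term.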


We can practically solve this optimization objective via gradient descent. However, some parameters such as $\pi$ live in constrained spaces like the probability simplex over $K$ clusters. To handle this, we apply invertible, one-to-one transformations from these constrained spaces to unconstrained real spaces, after which standard gradient methods apply directly.
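The text does not specify which transformations we use. As one plausible choice (an assumption), the sketch below maps $K-1$ free reals to the simplex via a softmax with the last logit pinned at zero, which keeps the map one-to-one, and maps a real number to $(m, \infty)$ via a shifted softplus, which suits positive parameters with a floor such as a standard deviation. All function names are hypothetical.

```python
import numpy as np


def simplex_from_real(u):
    """Map u in R^{K-1} to a strictly positive point on the K-simplex."""
    logits = np.append(u, 0.0)           # pin the last logit to keep the map invertible
    logits = logits - np.max(logits)     # stabilize before exponentiating
    e = np.exp(logits)
    return e / np.sum(e)


def real_from_simplex(p):
    """Inverse of simplex_from_real for strictly positive p."""
    p = np.asarray(p, dtype=float)
    return np.log(p[:-1]) - np.log(p[-1])


def positive_from_real(u, min_val=0.0):
    """Map u in R to (min_val, inf) via a shifted softplus."""
    return min_val + np.logaddexp(0.0, u)


def real_from_positive(s, min_val=0.0):
    """Inverse softplus: recover u from s > min_val."""
    return np.log(np.expm1(s - min_val))
```

Gradients are then taken with respect to the unconstrained `u` values, and every gradient step automatically yields valid constrained parameters.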

In practice, to train supervised mixtures we use the Adam gradient descent procedure (Kingma and Ba, 2014), which requires a baseline learning rate (we search over the small grid 0.1, 0.01, 0.001) that is then adaptively scaled for each parameter dimension to improve convergence. We initialize parameters via random draws from reasonable ranges and run several thousand gradient update steps to converge to a local optimum. Because the objective is non-convex, we use many random restarts (at least 5, preferably more) for each learning rate and keep the single snapshot with the lowest training objective.
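The restart protocol above can be sketched in plain NumPy. This is a hedged illustration, not our implementation: the Adam constants are the defaults suggested by Kingma and Ba, the step count and restart count are placeholders, and the quadratic objective in the usage note is a stand-in for the PC loss. Function names are hypothetical.

```python
import numpy as np


def adam_minimize(loss_fn, grad_fn, theta0, lr, n_steps=2000,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Plain Adam on unconstrained parameters; returns (theta, final loss)."""
    theta = np.array(theta0, dtype=float)
    m = np.zeros_like(theta)   # running mean of gradients
    v = np.zeros_like(theta)   # running mean of squared gradients
    for t in range(1, n_steps + 1):
        g = grad_fn(theta)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g ** 2
        m_hat = m / (1 - b1 ** t)  # bias corrections
        v_hat = v / (1 - b2 ** t)
        # baseline lr, rescaled per dimension by that coordinate's gradient history
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, loss_fn(theta)


def best_of_restarts(loss_fn, grad_fn, dim, lrs=(0.1, 0.01, 0.001),
                     n_restarts=5, seed=0):
    """Run Adam from several random inits per learning rate; keep the lowest loss."""
    rng = np.random.default_rng(seed)
    best_theta, best_loss = None, np.inf
    for lr in lrs:
        for _ in range(n_restarts):
            theta, loss = adam_minimize(loss_fn, grad_fn, rng.normal(size=dim), lr)
            if loss < best_loss:
                best_theta, best_loss = theta, loss
    return best_theta, best_loss
```

For example, `best_of_restarts(lambda th: float(np.sum((th - 3) ** 2)), lambda th: 2 * (th - 3), dim=2)` should return parameters close to 3 in both coordinates.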

3.2 Toy Example: Why Asymmetry Matters

We now consider a small example to illustrate one of our fundamental contributions: PC training is often superior to symmetric maximum likelihood training with label replication at finding models that accurately predict labels $y$ given data $x$. We apply supervised mixture models to a simple toy dataset with data $x_n$ on the real line and binary labels $y_n \in \{0, 1\}$. We construct the data by drawing from three uniform distributions over distinct intervals of the real line, which we label in order from left to right for later reference: interval A contains 175 data points with a roughly even mix of positive and negative labels; interval B contains 100 points with purely positive labels; and interval C contains 75 points with purely negative labels. The top rows of Fig. 2 show the observed training data as stacked histograms, colored by label.
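A sketch of how such a dataset could be generated follows. The interval endpoints are not given in the text, so the ones below are hypothetical placeholders; only the counts (175, 100, 75) and label proportions follow the description, and the “roughly even” mix in interval A is modeled as independent fair coin flips (an assumption).

```python
import numpy as np


def make_toy_intervals(seed=0):
    """Three-interval toy data; endpoints are illustrative, not the paper's."""
    rng = np.random.default_rng(seed)
    # name: (low, high, count, P(y = 1)); endpoints are hypothetical
    intervals = {"A": (-1.0, 0.5, 175, 0.5),
                 "B": (1.0, 1.5, 100, 1.0),
                 "C": (1.5, 2.0, 75, 0.0)}
    xs, ys = [], []
    for low, high, count, p_pos in intervals.values():
        xs.append(rng.uniform(low, high, size=count))
        ys.append((rng.uniform(size=count) < p_pos).astype(int))
    return np.concatenate(xs), np.concatenate(ys)
```

The returned arrays are ordered A, B, C, so labels in the B block are all 1 and those in the C block are all 0.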

Fig. 2: Toy example from Sec. 3.2: asymmetric prediction-constrained (PC) training predicts labels better than symmetric joint maximum likelihood training with label replication (ML+rep). Top rows: Estimated 2-cluster Gaussian mixture model for each training procedure under different weight values $\lambda$, taking the best of many initializations using the relevant training objective function. Curves show the estimated 1D Gaussian distribution for each cluster. Upper left text in each panel gives the estimated probability that each cluster will emit a positive label. Colors are assigned so that the red cluster has higher probability of emitting positive labels. Stacked histograms of the 1D training dataset are overlaid in the background (blue shading means $y = 0$, red means $y = 1$). Bottom row: Area under the ROC curve (AUC) and error rate for predicting labels $y$ from data $x$ on the training data, using the best solution (as ranked by each training objective) across different weight values $\lambda$. The final panel shows the negative log likelihood of the data $x$ (normalized by the number of data points) across the same $\lambda$ values.

We now wish to train a supervised mixture model for this dataset. To fully specify the model, we must define concrete densities and parameter spaces. For the data likelihood $p(x_n \mid \phi_k)$, we use a 1D Gaussian distribution $\mathcal{N}(x_n \mid \mu_k, \sigma_k^2)$, with two parameters for each cluster $k$: $\phi_k = \{\mu_k, \sigma_k\}$. The mean $\mu_k$ can take any real value, while the standard deviation $\sigma_k$ is positive with a small minimum value to avoid degeneracy: $\sigma_k \geq \sigma_{\min} > 0$. For the label likelihood $p(y_n \mid \eta_k)$, we select a Bernoulli likelihood $\mathrm{Bern}(y_n \mid \rho_k)$, which has one parameter per cluster, $\eta_k = \rho_k \in (0, 1)$, defining the probability that labels produced by cluster $k$ will be positive. For simplicity, we fix the model structure to exactly $K = 2$ clusters.

We apply very light regularization to only the cluster frequencies and the per-cluster positive-label probabilities:


These choices ensure that MAP estimates of these two sets of probabilities are unique and always lie in numerically valid ranges (never on the boundary values of exactly 0 or 1). This is helpful for the closed-form maximization step of the EM algorithm we use for the ML+rep objective.
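The exact regularizer is not reproduced here. Assuming, for illustration only, conjugate symmetric Dirichlet and Beta priors with a concentration `a` slightly above one, the closed-form MAP M-step for the cluster frequencies and per-cluster positive-label probabilities looks like the sketch below. The function name, `a`, and the replication weight `rep` (which multiplies the label counts, as in ML+rep) are hypothetical.

```python
import numpy as np


def m_step_pi_rho(resp, y, a=1.01, rep=1.0):
    """Closed-form MAP updates given EM responsibilities.

    resp: (N, K) responsibilities p(z_n = k | x_n, y_n)
    y:    (N,) binary labels
    a:    hypothetical Dirichlet/Beta concentration; any a > 1 keeps
          every estimate strictly inside (0, 1)
    rep:  label replication weight applied to the label counts
    """
    resp = np.asarray(resp, dtype=float)
    y = np.asarray(y, dtype=float)
    K = resp.shape[1]
    n_k = resp.sum(axis=0)        # expected number of points per cluster
    n_k_pos = resp.T @ y          # expected number of positive labels per cluster
    pi = (n_k + a - 1) / (n_k.sum() + K * (a - 1))
    rho = (rep * n_k_pos + a - 1) / (rep * n_k + 2 * (a - 1))
    return pi, rho
```

Without the `a - 1` pseudo-counts, a cluster holding only positive labels would get `rho` exactly 1, whose log-probability for a negative label is undefined.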

When using this model to explain this dataset, there is a fundamental tension between explaining the data $x$ and the labels $y$: no one set of parameters will outrank all others on both objectives. For example, standard joint maximum likelihood training (equivalent to our PC objective when $\lambda = 1$) happens to prefer a mixture model with two well-separated Gaussian clusters with means around 0 and 1.5. This gives reasonable coverage of the data density $p(x)$, but has quite poor predictive performance $p(y \mid x)$, because the left cluster is centered over interval A (a non-separable even mix of positive and negative examples), while the right cluster explains both B and C (which together contain 100 positive and 75 negative examples).

Our PC training objective allows prioritizing the prediction of $y$ given $x$ by increasing the Lagrange multiplier weight $\lambda$. Fig. 2 shows that for sufficiently large $\lambda$, the PC objective prefers a solution in which one cluster (colored red) exclusively explains interval B, which has only positive labels. The other cluster (colored blue) has wider variance to cover all remaining data points. This solution has a much lower error rate and higher AUC than the basic $\lambda = 1$ solution. Of course, the tradeoff is a visibly lower likelihood of the training data $x$, since the higher-variance blue cluster does less well at explaining the empirical distribution of $x$. As $\lambda$ increases beyond 4, label prediction improves slightly as the decision boundaries get even sharper, but this requires the blue background cluster to drift further away from the data and reduces data likelihood even more. In total, this example illustrates how PC training lets the practitioner explore a range of models that trade off data likelihood and prediction quality.

In contrast, no amount of label replication in standard maximum likelihood training reaches the prediction quality obtained by our PC approach. Fig. 2 shows trained models for replication weights equal to 1, 4, 16, and 64 (for simplicity, we use the same symbol $\lambda$ for both the replication weight and the PC weight). For all of these values, symmetric joint ML+rep training finds the same solution: Gaussian clusters that are exclusively dedicated to either purely positive or purely negative labels. This occurs because at training time both $x$ and $y$ are fully observed, so the replicated presence of $y$ strongly cues which cluster to assign and allows perfect label classification. However, when we then perform asymmetric prediction of $y$ given only $x$ on the same training data, performance is much worse: the error rate is roughly 0.4, while our PC method achieves near 0.25. It is important to stress that no amount of label replication would fix this, because the asymmetric task of predicting $y$ given only $x$ is not the focus of the symmetric joint likelihood objective.

3.3 Toy Example: Advantage of Semi-supervised PC Training

(a) PC: Prediction-constrained (b) ML+rep: Maximum likelihood with label replication
Fig. 3: Toy example from Sec. 3.3: Estimated supervised mixture models produced by PC training (a) and ML+rep (b) for semi-supervised tasks with few labeled examples. Each panel shows the 2D elliptical contours of the estimated 2-cluster Gaussian mixture model which scored best under each training objective, using the indicated weight $\lambda$ and percentage of examples with observed labels at training time, which varies from 3% to 100%. Upper text in each panel gives the estimated probability that each cluster will emit a positive label. Colors are assigned so that the red cluster has higher probability of emitting positive labels. In the background of each panel is a scatter plot of the first two dimensions of the data $x_n$, with each point colored by its binary label (grey = negative, black = positive).
Fig. 4: Toy example from Sec. 3.3: Each panel shows line plots of performance metrics as the PC or replication weight $\lambda$ increases, for a particular percentage of labeled data. The top row shows label prediction error rate (lower is better), and the bottom row shows the negative log likelihood of the data (lower is better). For visualizations of the corresponding parameters, see Fig. 3.

Next, we study how our PC training objective enables useful analysis of semi-supervised datasets, which contain many unlabeled examples and few labeled examples. Again, we will illustrate clear advantages of our approach over standard maximum likelihood training in prediction quality.

The dataset is generated in two stages. First, we generate 5000 data vectors $x_n \in \mathbb{R}^5$ drawn from a mixture of 2 well-separated Gaussians with diagonal covariance matrices:

Next, we generate binary labels $y_n$ according to a fixed threshold rule which uses only the absolute value of the second dimension of $x_n$:


While the full data vectors $x_n$ are 5-dimensional, we can visualize their first two dimensions as a scatter plot in Fig. 3. Each point is annotated by its binary label $y_n$: 0-labeled points are grey 'x' markers, while 1-labeled points are black 'o' markers. Finally, we make the problem semi-supervised by selecting some percentage $p$ of the 5000 data points to keep labeled during training. For example, if $p = 50\%$, then we train using 2500 labeled pairs randomly selected from the full dataset as well as the remaining 2500 unlabeled data points. Our model specification is the same as in the previous example: a Gaussian with diagonal covariance for $p(x_n \mid \phi_k)$, a Bernoulli likelihood for $p(y_n \mid \eta_k)$, and the same light regularization as before, which allows closed-form, numerically valid M-steps when optimizing the ML+rep objective via EM.
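The labeled/unlabeled split can be sketched as follows, assuming labels are kept uniformly at random without replacement; the function name is hypothetical.

```python
import numpy as np


def make_semisupervised_mask(n, frac_labeled, seed=0):
    """Boolean mask that keeps round(frac_labeled * n) randomly chosen labels."""
    rng = np.random.default_rng(seed)
    n_labeled = int(round(frac_labeled * n))
    mask = np.zeros(n, dtype=bool)
    mask[rng.choice(n, size=n_labeled, replace=False)] = True
    return mask
```

With `n=5000`, a fraction of 0.5 keeps 2500 labels, and a fraction of 0.03 keeps 150; examples where the mask is `False` contribute only to the data likelihood term during training.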

We have deliberately constructed this dataset so that a supervised mixture model is misspecified. Either the model will do well at capturing the data density by covering the two well-separated blobs with equal-covariance Gaussians, or it will model the predictive density well by using a thin horizontal Gaussian to model the black points as well as a much larger background Gaussian to capture the rest. With only 2 clusters, no single model can do well at both.

Our PC approach provides a range of possible models to consider, one for each value of $\lambda$, which trade off these two objectives. Line plots of overall performance trends for data likelihood and prediction quality are shown in Fig. 4, while the corresponding parameter visualizations are shown in Fig. 3. Overall, PC training with $\lambda = 1$, which is equivalent to standard ML training, yields a solution which explains the data well but is poor at label prediction. For all tested fractions of labeled data $p$, as we increase $\lambda$ there exists some critical point at which this solution is no longer preferred, and the objective instead favors a solution with near-zero error rate for label prediction. The value of $\lambda$ at which this transition occurs depends on the fraction of labeled data (see Fig. 4).

In contrast, when we test symmetric ML training with label replication across many replication weights, we see big differences between plentiful labels (large $p$) and scarce labels (small $p$). When enough labeled examples are available, high replication weights do favor the same near-zero error rate solution found by our PC approach. However, there is some critical value of $p$ below which this solution is no longer favored, and the preferred solution for label replication is instead a pathological one: two well-separated clusters that explain the data well but have extreme label probabilities. Consider the ML+rep solution for $p = 3\%$ in Fig. 3. The red cluster explains the left blob of unlabeled data (containing about 2400 data points) as well as all positive labels observed at training, which occur in both the left and right blobs (only 150 labels exist in total, of which about half are positive). The symmetric joint ML objective weighs each data point, whether labeled or unlabeled, equally when updating the parameters $\phi$ that control the data density $p(x)$, no matter how much replication occurs. Thus, the many unlabeled points exert strong influence toward the well-separated blob configuration of the data density, and the few labeled points can be easily explained as outliers to the two blobs. In contrast, our PC objective by construction allows upweighting the influence of the asymmetric prediction task on all parameters, including $\phi$. Thus, even when replication yields good predictions when all labels are observed, it can yield pathologies with few labels that our PC approach easily avoids.

4 Case Study: Prediction-constrained Topic Models

We now present a much more thorough case study of prediction-constrained topic models, building on latent Dirichlet allocation (LDA) (Blei et al., 2003) and its downstream supervised extension sLDA (McAuliffe and Blei, 2007). The unsupervised LDA topic model takes as observed data a collection of documents, or more generally, groups of discrete data. Each document $d$ is represented by counts of $V$ discrete word types or features, $x_d \in \{0, 1, 2, \ldots\}^V$. We explain these observations via $K$ latent clusters or topics, such that each document exhibits mixed membership across these topics. Specifically, in terms of our general downstream LVM family, the model assumes a hidden variable $\pi_d$, a vector of $K$ positive numbers that sum to one, indicating the fraction of document $d$ explained by each topic $k$. The generative model is:

$$\pi_d \mid \alpha \sim \mathrm{Dir}(\alpha, \ldots, \alpha), \qquad x_d \mid \pi_d, \phi \sim \mathrm{Mult}\Big(N_d,\; \sum_{k=1}^{K} \pi_{dk}\, \phi_k\Big).$$


Here, the hidden variable prior density $p(\pi_d \mid \alpha)$ is chosen to be a symmetric Dirichlet with scalar concentration parameter $\alpha$. Similarly, the data likelihood parameters are $\phi = \{\phi_k\}_{k=1}^{K}$, where each topic $k$ has a parameter vector $\phi_k$ of $V$ positive numbers (one for each vocabulary term) that sums to one. The value $\phi_{kv}$ defines the probability of generating word $v$ under topic $k$. Finally, we assume that the size of document $d$ is observed as $N_d = \sum_{v=1}^{V} x_{dv}$.

In the supervised setting, we assume that each document also has an observed target value $y_d$. For our applications, we assume this is one or more binary labels, so $y_d \in \{0, 1\}$, but we emphasize that other types of values are easily handled via generalized linear models (McAuliffe and Blei, 2007). Standard supervised topic models like sLDA assume labels and word counts are conditionally independent given topic probabilities $\pi_d$, via the label likelihood:

$$p(y_d \mid \pi_d, \eta) = \mathrm{Bern}\Big(y_d \;\Big|\; \sigma\Big(\sum_{k=1}^{K} \pi_{dk}\, \eta_k\Big)\Big),$$


where $\sigma(a) = 1/(1 + e^{-a})$ is the logistic sigmoid function (the inverse of the logit function), and $\eta \in \mathbb{R}^K$ is a vector of real-valued regression parameters. Under this model, large positive $\eta_k$ values imply that high usage of topic $k$ in a given document (larger $\pi_{dk}$) will lead to predictions of a positive label $y_d = 1$. Large negative $\eta_k$ values imply that high topic usage leads to a negative label prediction $y_d = 0$.
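To make the marginalized generative process concrete, the following sketch samples one document's topic proportions, word counts, and binary label. It assumes `phi` is a $K \times V$ array with rows on the simplex; the function name is hypothetical, and the calls are standard NumPy `Generator` methods.

```python
import numpy as np


def sample_slda_doc(phi, eta, alpha, n_tokens, rng):
    """Draw (pi_d, x_d, y_d) from the topic-marginalized sLDA model.

    phi: (K, V) topic-word probabilities, rows sum to one
    eta: (K,) logistic regression weights
    """
    K = phi.shape[0]
    pi_d = rng.dirichlet(alpha * np.ones(K))       # topic proportions
    word_probs = pi_d @ phi                         # (V,) mixture of topics
    word_probs = word_probs / word_probs.sum()      # guard against rounding drift
    x_d = rng.multinomial(n_tokens, word_probs)     # word counts, sum to N_d
    p_pos = 1.0 / (1.0 + np.exp(-(pi_d @ eta)))     # logistic sigmoid
    y_d = int(rng.uniform() < p_pos)                # Bernoulli label
    return pi_d, x_d, y_d
```

Note that the label depends on the document only through $\pi_d$, which is exactly the conditional independence assumption stated above.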

The original sLDA model (McAuliffe and Blei, 2007) represents the count likelihood via independent assignments $z_{dn}$ of word tokens to topics, and generates labels $y_d \sim \mathrm{Bern}(\sigma(\sum_k \bar{z}_{dk}\, \eta_k))$, where $\bar{z}_d$ is a vector on the $K$-dimensional probability simplex given by the empirical distribution of the token-to-topic assignments: $z_{dn} \sim \mathrm{Cat}(\pi_d)$ and $\bar{z}_{dk} = \frac{1}{N_d} \sum_{n=1}^{N_d} \mathbb{1}[z_{dn} = k]$. To enable more efficient inference algorithms, we analytically marginalize these topic assignments away in Eq. (27,28).

PC objective for sLDA.

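Applying the PC objective of Eq. (5) to the sLDA model gives:

$$\min_{\phi, \eta} \; -\sum_{d} \log p(x_d \mid \phi, \alpha) \;-\; \lambda \sum_{d} \log p(y_d \mid x_d, \phi, \eta, \alpha).$$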


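Computing $p(x_d \mid \phi, \alpha)$ and $p(y_d \mid x_d, \phi, \eta, \alpha)$ involves marginalizing out the latent variables $\pi_d$:

$$p(x_d \mid \phi, \alpha) = \int p(\pi_d \mid \alpha)\, p(x_d \mid \pi_d, \phi)\, d\pi_d,$$
$$p(y_d \mid x_d, \phi, \eta, \alpha) = \int p(y_d \mid \pi_d, \eta)\, p(\pi_d \mid x_d, \phi, \alpha)\, d\pi_d. \tag{31}$$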


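Unfortunately, these integrals are intractable. To gain traction, we first contemplate an objective that instantiates $\pi_d$ rather than marginalizing it away:

$$\min_{\phi, \eta, \{\pi_d\}} \; -\sum_{d} \Big[ \log p(\pi_d \mid \alpha) + \log p(x_d \mid \pi_d, \phi) + \lambda \log p(y_d \mid \pi_d, \eta) \Big].$$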


However, this objective is simply a version of maximum likelihood with label replication from Sec. 2.3, albeit with hidden variables instantiated rather than marginalized. The same poor prediction quality issues will arise due to its inherent symmetry. Instead, because we wish to train under the same asymmetric conditions needed at test time, where we have $x_d$ but not $y_d$, we do not instantiate $\pi_d$ as a free variable but fix it to a deterministic mapping of the words $x_d$ to the topic simplex. Specifically, we fix $\pi_d$ to the maximum a posteriori (MAP) solution $\arg\max_{\pi_d} \log p(\pi_d \mid x_d, \phi, \alpha)$, which we write as a deterministic function $\hat{\pi}_d(x_d, \phi, \alpha)$. We show in Sec. 4.1 that this deterministic embedding of any document’s data onto the topic simplex is easy to compute. Our chosen embedding can be seen as a feasible approximation to the full posterior $p(\pi_d \mid x_d, \phi, \alpha)$ needed in Eq. (31). This choice respects the need to use the same embedding of observed words $x_d$ into the low-dimensional $\pi_d$ in both training and test scenarios.

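We can now write a tractable training objective we wish to minimize:

$$\min_{\phi, \eta} \; -\sum_{d} \log p\big(x_d \mid \hat{\pi}_d(x_d, \phi, \alpha), \phi\big) \;-\; \lambda \sum_{d} \log p\big(y_d \mid \hat{\pi}_d(x_d, \phi, \alpha), \eta\big). \tag{33}$$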


This objective is both tractable to evaluate and fixes the asymmetry issue of standard sLDA training, because the model is forced to learn the embedding function which will be used at test time.

Previous training objectives for sLDA.

Originally, the sLDA model was trained via a variational EM algorithm that optimizes a lower bound on the marginal likelihood of the observed words $x$ and labels $y$ (McAuliffe and Blei, 2007); MCMC sampling for posterior estimation is also possible. This treatment ignores the cardinality mismatch and asymmetry issues, making it difficult to make good predictions of $y$ given $x$ under conditions of model mismatch. Alternatives like MED-LDA (Zhu et al., 2012) offer objectives which try to enforce constraints on the loss function given expectations under the approximate posterior, yet these objectives still ignore the crucial asymmetry issue. We also showed earlier in Sec. 2.3 that some MED objectives can be reduced to ineffective maximum likelihood with label replication.

Recently, Chen et al. (2015) developed backpropagation methods called BP-LDA and BP-sLDA for the unsupervised and supervised versions of LDA. They train using extreme cases of our end-to-end weighted objective in Eq. (33): for supervised BP-sLDA the entire data likelihood term is omitted, and for unsupervised BP-LDA the entire label likelihood term is omitted. In contrast, our overriding goal of guaranteeing some minimum prediction quality via our PC objective in Eq. (5) leads to a Lagrange multiplier $\lambda$ which allows us to systematically balance the generative and discriminative objectives. BP-sLDA offers no such tradeoff, and we will see in later experiments that while its label predictions are sometimes good, the underlying topic model is quite poor at explaining held-out data and yields difficult-to-interpret topic-word distributions.

4.1 Inference and learning for Prediction-Constrained LDA

Fitting the sLDA model to a given dataset using our PC optimization objective in Eq. (33) requires two concrete procedures: per-document inference to compute the hidden variable $\pi_d$, and global parameter estimation of the topic-word parameters $\phi$ and logistic regression weight vector $\eta$. First, we show how the MAP embedding $\hat{\pi}_d$ can be computed via several iterations of an exponentiated gradient procedure that exploits convex structure. Second, we show how we can differentiate through the entire objective to perform gradient descent on our parameters of interest $\phi$ and $\eta$. While our experiments assume that the prior concentration parameter $\alpha$ is a fixed constant, it could easily be optimized via the same procedure.

MAP inference via exponentiated gradient iterations.

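Sontag and Roy (2011) define the document-topic MAP estimation problem for LDA as:

$$\hat{\pi}_d = \arg\max_{\pi_d \in \Delta^K} \; \sum_{v=1}^{V} x_{dv} \log \Big( \sum_{k=1}^{K} \pi_{dk}\, \phi_{kv} \Big) + (\alpha - 1) \sum_{k=1}^{K} \log \pi_{dk},$$

where $\Delta^K$ denotes the probability simplex over the $K$ topics.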


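This problem is convex for $\alpha \geq 1$ and non-convex otherwise. For the convex case, they suggest an iterative exponentiated gradient algorithm (Kivinen and Warmuth, 1997). This procedure begins with a uniform probability vector and, at each iteration $t$, performs elementwise multiplication with the exponentiated gradient followed by renormalization, using a scalar step size $\nu$, until convergence:

$$\pi_{dk}^{(t)} \propto \pi_{dk}^{(t-1)} \exp\big(\nu\, g_{dk}^{(t-1)}\big), \qquad g_{dk}^{(t-1)} = \sum_{v=1}^{V} \frac{x_{dv}\, \phi_{kv}}{\sum_{j=1}^{K} \pi_{dj}^{(t-1)} \phi_{jv}} + \frac{\alpha - 1}{\pi_{dk}^{(t-1)}}.$$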


With small enough steps, the iterates converge to the MAP solution as the number of iterations $T$ grows. We thus define our embedding function $\hat{\pi}_d(x_d, \phi, \alpha)$ to be the outcome of $T$ iterations of the above procedure. In practice, we find that a fixed, modest number of iterations $T$ and a small fixed step size $\nu$ work well. Line search for $\nu$ could reduce the number of iterations needed (though it would increase per-iteration costs).

Importantly, Taddy (2012) points out that while the non-convex case $\alpha < 1$ has no single MAP solution for $\pi_d$ in the simplex, due to the multimodal, sparsity-promoting Dirichlet prior, a simple reparameterization into the softmax basis (MacKay, 1997) leads to a unimodal posterior and thus a unique MAP in this reparameterized space. Elegantly, this softmax-basis solution for a particular $\alpha$ equals the simplex MAP estimate for the “add one” posterior $p(\pi_d \mid x_d, \phi, \alpha + 1)$. Thus, we can use our exponentiated gradient procedure to reliably perform natural parameter MAP estimation even for $\alpha < 1$ via this “add one” trick.
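A minimal NumPy sketch of the exponentiated gradient iterations, including the optional “add one” trick of replacing $\alpha$ by $\alpha + 1$ before iterating, follows. The iteration count and step size we use in experiments are not reproduced here, so they are required arguments; the function name and the values in the usage note are illustrative only.

```python
import numpy as np


def map_doc_topics(x_d, phi, alpha, n_iters, step, add_one=True):
    """Exponentiated-gradient MAP estimate of topic proportions pi_d.

    x_d:     (V,) word counts for one document
    phi:     (K, V) topic-word probabilities
    alpha:   symmetric Dirichlet concentration
    add_one: run on the alpha + 1 posterior, which keeps the problem
             convex even when alpha < 1
    """
    K = phi.shape[0]
    a = alpha + 1.0 if add_one else alpha
    pi_d = np.full(K, 1.0 / K)                   # start from the uniform vector
    for _ in range(n_iters):
        word_mix = pi_d @ phi                    # (V,) sum_k pi_dk phi_kv
        grad = phi @ (x_d / word_mix) + (a - 1.0) / pi_d
        log_w = np.log(pi_d) + step * grad       # multiplicative update in log space
        log_w -= log_w.max()                     # stabilize before exponentiating
        pi_d = np.exp(log_w)
        pi_d /= pi_d.sum()                       # renormalize onto the simplex
    return pi_d
```

For instance, with two topics `phi = [[0.9, 0.1], [0.1, 0.9]]`, counts `x_d = [7, 3]`, `alpha=1.0`, and `add_one=False`, setting the two gradient entries equal gives the MAP `[0.75, 0.25]`, which the iterations approach with `step=0.05` after a few hundred steps. Since every operation is differentiable, an autodiff framework can backpropagate through the unrolled loop.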

Global parameter estimation via stochastic gradient descent.

To optimize the objective in Eq. (33), we first note that the iterative MAP estimation function above is fully differentiable with respect to its inputs $\phi$ and $\alpha$, as are the probability density functions $p(\pi_d \mid \alpha)$, $p(x_d \mid \pi_d, \phi)$, and $p(y_d \mid \pi_d, \eta)$. This means the entire objective is differentiable and modern gradient descent methods may be applied easily. Of course, this requires standard transformations of constrained parameters, like the topic-word distributions $\phi$, from the simplex to unrestricted real vectors. Once the loss function is specified via unconstrained parameters, we use automatic differentiation to compute gradients and then perform gradient descent via the Adam algorithm (Kingma and Ba, 2014), which easily allows stochastically sampling minibatches of data for each gradient update. In practice, we have developed Python implementations based on both Autograd (Maclaurin et al., 2015) and TensorFlow (Abadi et al., 2015), which we plan to release to the public.

Earlier work by Chen et al. (2015) optimized their fully discriminative objective via a mirror descent algorithm directly in the constrained parameters $\phi$, using manually derived gradient computations within a heroically complex implementation in the C# language. Our approach has the advantage of easily extending to other supervised loss functions without the need to derive and implement gradient calculations, although automatic differentiation can be slow.

Hyperparameter selection.

The key hyperparameter of our prediction-constrained LDA algorithm is the Lagrange multiplier $\lambda$. Generally, for topic models of text data $\lambda$ needs to be on the order of the number of tokens in the average document, because the data term sums log probabilities over every token in a document while the label term covers only that document's few labels; it may need to be much larger depending on how much tension exists between the unsupervised and supervised terms of the objective. If possible, we suggest trying a range of logarithmically spaced $\lambda$ values and selecting the best on validation data, although this requires expensive retraining at each value. This cost can be somewhat mitigated by using the final parameters at one $\lambda$ value as the initial parameters at the next, although this may not escape to new preferred basins of attraction in the overall non-convex objective.
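The warm-started sweep over log-spaced $\lambda$ values can be sketched as follows. Here `fit` and `score_val` are hypothetical callables standing in for a full PC-LDA training run and a validation metric, and the default range and grid size are placeholders.

```python
import numpy as np


def lambda_sweep(fit, score_val, init_params, lam_min=1.0, lam_max=1e3, n_vals=7):
    """Fit at log-spaced lambda values, warm-starting each fit from the last.

    fit(params, lam) -> new params        (hypothetical training routine)
    score_val(params) -> validation loss  (lower is better)
    """
    lams = np.logspace(np.log10(lam_min), np.log10(lam_max), n_vals)
    params, results = init_params, []
    for lam in lams:
        params = fit(params, lam)  # warm start from the previous solution
        results.append((lam, score_val(params), params))
    best_lam, _, best_params = min(results, key=lambda r: r[1])
    return best_lam, best_params, results
```

Passing a fresh `init_params` to every call instead of the previous `params` recovers the more expensive independent-retraining protocol, which can reach basins that warm starts miss.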

4.2 Supervised LDA Experimental Results

We now assess how well our proposed PC training of sLDA, which we hereafter abbreviate as PC-LDA, achieves its simultaneous goals of solid held-out prediction of labels $y$ given $x$ while maintaining reasonably interpretable explanations of words $x$. We test the first goal by comparing to discriminative methods like logistic regression and supervised topic models, and the second by comparing to unsupervised topic models. For full descriptions of all datasets and protocols, as well as more results, please see the appendix.


Our discriminative baselines include logistic regression (with a validation-tuned regularizer), the fully supervised BP-sLDA algorithm of Chen et al. (2015), and the supervised MED-LDA Gibbs sampler (Zhu et al., 2013), which should improve on the original variational MED-LDA algorithm of Zhu et al. (2012). We also consider our own implementation of standard coordinate-ascent variational inference for both unsupervised (VB LDA) and supervised (VB sLDA) topic models. Finally, we consider a vanilla Gibbs sampler for LDA, using the Mallet toolbox (McCallum, 2002). We use third-party public code when possible for single-label-per-document experiments, but only our own PC-LDA and VB implementations support multiple binary labels per document, which occur in our later Yelp review label prediction and electronic health record drug prediction tasks. For these datasets only, the method we call BP-sLDA is a special case of our own PC-LDA implementation (removing the data likelihood term), which we have verified is comparable to the single-target-only public implementation but allows multiple binary targets.


For each dataset, we reserve two distinct subsets of documents: one for validation of key hyperparameters, and another to report held-out metrics. All topic models are run from multiple random initializations of the topic-word parameters $\phi$ and regression weights $\eta$ (for fairness, all methods use the same set of predefined initializations of these parameters). We record point estimates of the topic-word parameters $\phi$ and logistic regression weights $\eta$ at defined intervals throughout training, and we select the best pair on the validation set (early stopping). For Bayesian methods like GibbsLDA, we select Dirichlet concentration hyperparameters via a small grid search on validation data, while for PC-LDA and BP-sLDA we set the Dirichlet concentration hyperparameters as recommended by Chen et al. (2015)