Contextual BERT: Conditioning the Language Model Using a Global State
BERT is a popular language model whose main pre-training task is to fill in the blank, i.e., to predict a word that was masked out of a sentence based on the remaining words. In some applications, however, additional context can help the model make the right prediction, e.g., by taking the domain or the time of writing into account. This motivates us to extend the BERT architecture with a global state for conditioning on a fixed-sized context. We present two novel approaches and apply them to an industry use-case, where we complete fashion outfits with missing articles, conditioned on a specific customer. An experimental comparison to other methods from the literature shows that our methods improve personalization significantly.
Since its publication, the BERT model by Devlin et al. (2019) has enjoyed great popularity in the natural language processing (NLP) community. To apply the model to a specific problem, it is commonly pre-trained on large amounts of unlabeled data, and subsequently fine-tuned on a target task. During both stages, the model's only input is a variably-sized sequence of words.
There are use-cases, however, where having an additional context can help the model. Consider a query intent classifier whose sole input is a user’s text query. Under the assumption that users from different age groups and professions express the same intent in different ways, the classifier would benefit from having access to that user context in addition to the query. Alternatively, one might consider training multiple models on separate, age group- and profession-specific samples. However, this approach does not scale well, requires more training data, and does not share knowledge between the models.
To the best of our knowledge, effective methods for conditioning BERT on a fixed-sized context are lacking. Motivated by this, and inspired by the graph-networks perspective on self-attention models (Battaglia et al., 2018), we extend BERT's architecture with a global state that enables conditioning. With our proposed methods [GS] and [GSU], we combine two previously independent streams of work. The first is centered around the idea of explicitly adding a global state to BERT, albeit without using it for conditioning. The second is focused on injecting additional knowledge into the BERT model. By using a global state for conditioning, we enable the application of BERT in a range of use-cases that require the model to make context-based predictions.
We use the outfit completion problem to test the performance of our new methods: The model predicts fashion items to complete an outfit and has to account for both style coherence and personalization. For the latter, we condition on a fixed-sized customer representation containing information such as customer age, style preferences, hair color, and body type. We compare our methods against two others from the literature and observe that ours are able to provide more personalized predictions.
2 Related Work
BERT’s Global State
In the original BERT paper, Devlin et al. (2019) use a [CLS] token which is prepended to the input sequence (e.g., a sentence of natural language). The assumption is that the model aggregates sentence-wide, global knowledge at the position of the [CLS] token. This intuition was confirmed through attention score analysis (Clark et al., 2019); however, the BERT architecture has no inductive bias that supports it. Recent work therefore treats the [CLS] token differently. Zaheer et al. (2020) constrain their BERT variant Big Bird to local attention only, with the exception that every position may always attend to [CLS] regardless of its spatial proximity.
Ke et al. (2020) also observe that the [CLS] attention exhibits peculiar patterns. This motivates them to introduce a separate set of weights for attending to and from [CLS]. The authors thereby explicitly encode into the architecture that the sequence's first position has a special role and a different modality than the other positions. The result is increased performance on downstream GLUE tasks.
Note that none of the related work on BERT's global state uses the global state for conditioning. Instead, the architectural changes are introduced solely to improve performance on non-contextual NLP benchmarks.
Conditioning on a Context
To the best of our knowledge, Wu et al. (2018) are the first to provide sentence-wide information to the model to ease the masked language model (MLM) pre-training task. The authors inject the target label (e.g., positive or negative review) of sentiment data by adding it to the [CLS] token embedding. In a similar application, Li et al. (2020) process the context separately and subsequently combine it with the model output to make a sentiment prediction.
The authors of arXiv:2004.01881 condition on richer information, namely an intent, which can be thought of as a task descriptor given to the model. The intent is represented in text form, is variably sized, and is prepended to the sequence. This is very similar to a wide range of GPT applications (Radford et al., 2019).
Chen et al. (2019) condition on a customer's variably-sized click history using a Transformer (Vaswani et al., 2017). Most similar to our work is the approach of Wu et al. (2020), who personalize by concatenating every position in the input sequence with a user embedding (method [C] in Section 3). Their approach, however, lacks an architectural bias that makes the model treat the user embedding as global information.
BERT as a Graph Neural Network (GNN)
Battaglia et al. (2018) introduce a framework that unites several lines of research on GNNs. In the appendix, the authors show that, within their framework, the Transformer architecture is a type of GNN; Joshi (2020) supports this finding. In both cases, the observation is that a sentence can be seen as a graph, where words correspond to nodes and the computation of an attention score is the assignment of a weight to an edge between two words.
In the GNN framework, a global state is accessible from every transfer function and can be individually updated from layer to layer.
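To make the graph view concrete, the NumPy sketch below computes single-head scaled dot-product self-attention twice. The first pass uses the usual matrix form. The second phrases the same computation as message passing over a complete graph of words, where the softmax-normalized score is the weight of the edge from word j to word i. Function names, dimensions, and the random weights are illustrative assumptions, not code from the cited works.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_matrix_form(X, W_Q, W_K, W_V):
    """Standard single-head scaled dot-product self-attention."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    return softmax(Q @ K.T / np.sqrt(K.shape[1])) @ V

def attention_message_passing(X, W_Q, W_K, W_V):
    """Same computation, phrased as a GNN over a complete graph of words."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    n = X.shape[0]
    out = np.zeros_like(V)
    for i in range(n):                          # receiving node (word i)
        scores = np.array([Q[i] @ K[j] for j in range(n)]) / np.sqrt(K.shape[1])
        edge_w = softmax(scores)                # weight of edge j -> i, for every j
        for j in range(n):                      # aggregate messages V[j] along weighted edges
            out[i] += edge_w[j] * V[j]
    return out

rng = np.random.default_rng(0)
n, d = 4, 6                                     # 4 "words", model width 6
X = rng.normal(size=(n, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
same = np.allclose(attention_matrix_form(X, W_Q, W_K, W_V),
                   attention_message_passing(X, W_Q, W_K, W_V))
print(same)  # True
```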
3 Conditioning BERT With a Global State
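Let $\mathbf{x} = (x_1, \ldots, x_n)$ denote a sequence of words from a fixed-sized vocabulary $V$. Further, let $\mathbf{x}_{\setminus i}$ be the sequence without the $i$th word. Recall that a vanilla BERT model (Devlin et al., 2019) can predict the probability $p(x_i \mid \mathbf{x}_{\setminus i}, m_i)$ of the word $x_i$ that was masked out of the sequence ($m_i$ being the masking event), conditioned on the other words in the sequence. Next, we introduce four methods to additionally condition BERT on a context vector $\mathbf{c} \in \mathbb{R}^{d_c}$, which allows it to predict $p(x_i \mid \mathbf{x}_{\setminus i}, m_i, \mathbf{c})$.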
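Concatenation [C]
Similar to Wu et al. (2020), we concatenate the context vector with every position in the input sequence. Let $\mathbf{e}_i \in \mathbb{R}^{d}$ denote the embedding of word $x_i$ at position $i$, and let $\mathbf{E} = [\mathbf{e}_1, \ldots, \mathbf{e}_n]^\top \in \mathbb{R}^{n \times d}$. The resulting input matrix is $\mathbf{H}^{(0)} = [\mathbf{E} \;\; \mathbf{1}_n \mathbf{c}^\top]\,\mathbf{W}_C \in \mathbb{R}^{n \times d}$, where $\mathbf{1}_n$ is the all-ones vector of length $n$ and $\mathbf{W}_C \in \mathbb{R}^{(d + d_c) \times d}$ is a trainable weight matrix that reduces the input dimensionality.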
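As a minimal NumPy sketch of this input construction, the code below concatenates the context vector to every word embedding and projects the result back to the model width. Toy dimensions and random stand-ins for the trainable projection are assumptions, and the helper name `concat_context_input` is hypothetical.

```python
import numpy as np

def concat_context_input(E, c, W_C):
    """Concatenate the context to every word embedding, then project.

    E:   (n, d) word embeddings, one row per position
    c:   (d_c,) context vector
    W_C: (d + d_c, d) projection back to the model width (trainable in practice)
    """
    n = E.shape[0]
    C_rep = np.tile(c, (n, 1))                        # (n, d_c): same context at every position
    return np.concatenate([E, C_rep], axis=1) @ W_C   # (n, d)

rng = np.random.default_rng(0)
n, d, d_c = 5, 8, 3
E = rng.normal(size=(n, d))
c = rng.normal(size=d_c)
W_C = rng.normal(size=(d + d_c, d))
H0 = concat_context_input(E, c, W_C)

# Splitting W_C shows the projection adds one context-dependent offset to every row.
offset = c @ W_C[d:]
print(H0.shape, np.allclose(H0, E @ W_C[:d] + offset))  # (5, 8) True
```

Splitting the projection matrix shows that this form of conditioning amounts to adding the same context-dependent offset to every projected word embedding. This is consistent with the remark in Section 2 that the approach lacks an architectural bias for treating the context as global information.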
New Position [NP]
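This method adds a new position to the input sequence at which the context is stored. It is comparable to how Wu et al. (2018) add label information. Instead of feeding only the word sequence into the model, we prepend the transformed context $\mathbf{c}^\top \mathbf{W}_{NP}$ to the sequence, where $\mathbf{W}_{NP} \in \mathbb{R}^{d_c \times d}$ is a trainable weight matrix. The resulting input is $\mathbf{H}^{(0)} = \begin{bmatrix} \mathbf{c}^\top \mathbf{W}_{NP} \\ \mathbf{E} \end{bmatrix} \in \mathbb{R}^{(n+1) \times d}$, where $\mathbf{E}$ stacks the word embeddings row by row. The model's attention masks are adjusted such that every position can attend to the new first position.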
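The sketch below prepends the transformed context as a new first position and extends a key-padding mask so that every position can attend to it. It uses hypothetical helper names and random weights. Representing the mask as a boolean matrix (True meaning "may attend") is an assumption made for illustration.

```python
import numpy as np

def prepend_context(E, c, W_NP):
    """[NP]: the transformed context becomes a new first position."""
    ctx = (c @ W_NP)[None, :]                    # (1, d)
    return np.concatenate([ctx, E], axis=0)      # (n + 1, d)

def extend_attention_mask(token_mask):
    """token_mask: (n,) bool, True for real items and False for padding.

    Returns an (n + 1, n + 1) bool matrix; entry [i, j] is True if position i
    may attend to position j. Column 0 (the context) is visible to everyone.
    """
    key_mask = np.concatenate([[True], token_mask])
    return np.tile(key_mask, (key_mask.size, 1))

rng = np.random.default_rng(0)
n, d, d_c = 4, 8, 3
E = rng.normal(size=(n, d))
c = rng.normal(size=d_c)
W_NP = rng.normal(size=(d_c, d))
H0 = prepend_context(E, c, W_NP)
mask = extend_attention_mask(np.array([True, True, True, False]))  # last item is padding
print(H0.shape, mask[:, 0].all(), mask[0, 4])  # (5, 8) True False
```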
Global State [GS]
Our method is inspired by the GNN perspective on BERT. Its implementation is similar to the way the Transformer decoder (Vaswani et al., 2017) attends to the encoder. [GS] treats the context as a read-only global state from which the internal representations can be updated. In order to adjust the architecture of BERT accordingly, we insert a global state attention layer between the intra-sequence attention and the (originally) subsequent feed-forward neural network (FNN). Figure 1 shows how the inserted elements fit into the vanilla BERT block.
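More formally, let $\mathbf{H}^{(l)}$ be the output of the $l$th BERT block (of which there are $L$); let $\mathbf{H}^{(0)}$ be the model input; and let $\mathbf{g} = f(\mathbf{c}) \in \mathbb{R}^{d}$ be the global state derived from the context vector $\mathbf{c}$ using a non-linear transformation $f$.
With our modification, $\mathbf{H}^{(l)}$ is defined by first performing the normal intra-sequence attention as in BERT,
$$\mathbf{A} = \mathrm{Attention}\big(\mathbf{H}^{(l-1)}\mathbf{W}_Q,\; \mathbf{H}^{(l-1)}\mathbf{W}_K,\; \mathbf{H}^{(l-1)}\mathbf{W}_V\big);$$
multi-head attention can be used here as well. Then, also unchanged,
$$\mathbf{B} = \mathrm{LayerNorm}\big(\mathbf{H}^{(l-1)} + \mathbf{A}\big).$$
The internal representation is then updated once more by reading from the global state,
$$\mathbf{G} = \mathrm{Attention}\big(\mathbf{B}\mathbf{W}^{g}_Q,\; \mathbf{g}^\top\mathbf{W}^{g}_K,\; \mathbf{g}^\top\mathbf{W}^{g}_V\big), \tag{1}$$
and computing $\mathbf{C} = \mathrm{LayerNorm}(\mathbf{B} + \mathbf{G})$. Lastly, the BERT layer output is computed as
$$\mathbf{H}^{(l)} = \mathrm{LayerNorm}\big(\mathbf{C} + \mathrm{FFN}(\mathbf{C})\big).$$
The definitions of $\mathrm{Attention}$, $\mathrm{LayerNorm}$, and $\mathrm{FFN}$ are identical to Vaswani et al. (2017). The weight matrices $\mathbf{W}_{Q,K,V}$ and $\mathbf{W}^{g}_{Q,K,V}$ are not shared between layers. The layer indices of $\mathbf{A}$, $\mathbf{B}$, $\mathbf{G}$, $\mathbf{C}$, and the weight matrices are omitted to aid readability.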
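A [GS] block can be sketched in NumPy as below. The following simplifications are assumptions made for illustration only:

- a single attention head;
- LayerNorm without learned gain and bias;
- a ReLU feed-forward network (BERT itself uses GELU);
- `tanh` of a linear map as the non-linear transformation from context to global state;
- random weights.

The sketch also checks the footnote's observation. Because the global state is a single key, the softmax weight is 1, and every position receives the same vector from the global-state attention.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Single-head scaled dot-product attention."""
    return softmax(Q @ K.T / np.sqrt(K.shape[1])) @ V

def layer_norm(X, eps=1e-12):
    """LayerNorm over the feature axis, without learned gain/bias (simplification)."""
    mu = X.mean(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(X.var(axis=-1, keepdims=True) + eps)

def init_block(rng, d, d_ff):
    """Random stand-ins for one block's trainable weights (not shared across layers)."""
    p = {k: rng.normal(scale=d ** -0.5, size=(d, d))
         for k in ["W_Q", "W_K", "W_V", "Wg_Q", "Wg_K", "Wg_V"]}
    p["W_1"] = rng.normal(scale=d ** -0.5, size=(d, d_ff))
    p["W_2"] = rng.normal(scale=d_ff ** -0.5, size=(d_ff, d))
    return p

def gs_block(H, g, p):
    """One [GS] block. H: (n, d) previous block output; g: (d,) global state."""
    A = attention(H @ p["W_Q"], H @ p["W_K"], H @ p["W_V"])   # intra-sequence attention
    B = layer_norm(H + A)
    g_seq = g[None, :]                                         # global state as a length-1 sequence
    G = attention(B @ p["Wg_Q"], g_seq @ p["Wg_K"], g_seq @ p["Wg_V"])  # read global state (Eq. 1)
    C = layer_norm(B + G)
    ffn = np.maximum(0.0, C @ p["W_1"]) @ p["W_2"]             # ReLU FFN (BERT uses GELU)
    return layer_norm(C + ffn), G

rng = np.random.default_rng(0)
n, d, d_c, d_ff = 5, 8, 3, 16
H0 = rng.normal(size=(n, d))
c = rng.normal(size=d_c)
W_f = rng.normal(size=(d_c, d))
g = np.tanh(c @ W_f)                 # assumed form of the non-linear transformation f
params = init_block(rng, d, d_ff)
H1, G = gs_block(H0, g, params)
print(H1.shape)                              # (5, 8)
print(np.allclose(G, g @ params["Wg_V"]))    # True: one key, so the softmax weight is 1
```

Stacking several such blocks, each with its own `init_block` parameters but the same `g`, gives the [GS] model.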
Global State With Update [GSU]
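Note that in the [GS] method the global state $\mathbf{g}$, which enters the internal representation through Equation 1, is the same for all blocks. With [GSU], a global state transfer function $u^{(l)}$ updates the global state, with its initial value being derived from the context: $\mathbf{g}^{(0)} = f(\mathbf{c})$ and $\mathbf{g}^{(l)} = u^{(l)}(\mathbf{g}^{(l-1)})$. The weights of $u^{(l)}$ are not shared between layers. Equation 1 is updated to read from the global state belonging to the $l$th layer:
$$\mathbf{G} = \mathrm{Attention}\big(\mathbf{B}\mathbf{W}^{g}_Q,\; \mathbf{g}^{(l)\top}\mathbf{W}^{g}_K,\; \mathbf{g}^{(l)\top}\mathbf{W}^{g}_V\big).$$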
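For [GSU], the self-contained sketch below updates the global state before every block using a per-layer transfer function. It follows the same plain-NumPy simplifications as a [GS] illustration would (one head, LayerNorm without learned parameters, ReLU FFN, random weights). The form `tanh(g @ U)` for the transfer function, and `tanh(c @ W_f)` for the initial state, are hypothetical choices. Because the global state is a single vector, the softmax weight in the global-state attention is always 1. The sketch therefore computes that read directly as the value projection of the current state, broadcast to every position. The query and key projections would not change the result.

```python
import numpy as np

def layer_norm(X, eps=1e-12):
    mu = X.mean(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(X.var(axis=-1, keepdims=True) + eps)

def self_attention(H, W_Q, W_K, W_V):
    S = (H @ W_Q) @ (H @ W_K).T / np.sqrt(W_K.shape[1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ (H @ W_V)

def gsu_forward(H, c, W_f, layers):
    """[GSU] forward pass. Returns the final representation and every global state."""
    g = np.tanh(c @ W_f)                     # initial global state from the context (assumed f)
    states = [g]
    for p in layers:                         # separate weights per layer, nothing shared
        g = np.tanh(g @ p["U"])              # per-layer global state update (assumed form of u)
        states.append(g)
        B = layer_norm(H + self_attention(H, p["W_Q"], p["W_K"], p["W_V"]))
        G = np.broadcast_to(g @ p["Wg_V"], B.shape)   # single-key read of this layer's state
        C = layer_norm(B + G)
        H = layer_norm(C + np.maximum(0.0, C @ p["W_1"]) @ p["W_2"])
    return H, states

rng = np.random.default_rng(1)
n, d, d_c, d_ff, num_layers = 5, 8, 3, 16, 2
shapes = {"W_Q": (d, d), "W_K": (d, d), "W_V": (d, d), "Wg_V": (d, d),
          "U": (d, d), "W_1": (d, d_ff), "W_2": (d_ff, d)}
layers = [{k: rng.normal(scale=s[0] ** -0.5, size=s) for k, s in shapes.items()}
          for _ in range(num_layers)]
H_out, states = gsu_forward(rng.normal(size=(n, d)), rng.normal(size=d_c),
                            rng.normal(size=(d_c, d)), layers)
print(H_out.shape, len(states))  # (5, 8) 3
```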
4 Empirical Evaluation and Discussion
We evaluate the performance of our proposed methods on a real-world industry problem: personalized fashion outfit completion (see Figure 2) for Europe’s largest fashion platform. Our proprietary dataset consists of 380k outfits, created by professional stylists for individual customers. When styling a customer, i.e., putting together an outfit, the stylist has access to all customer features that we later use to condition the model. Therefore, customer data and outfit are statistically dependent.
The customer features are individually embedded using trainable, randomly initialized embedding spaces. The per-feature embedding vectors are subsequently concatenated, yielding the context vector $\mathbf{c} \in \mathbb{R}^{d_c}$. Features include the customer's age, gender, country, preferred brands/colors/styles, no-go types, clothing sizes, price preferences, and the occasion for which the outfit is needed. The second model input, namely the outfit itself, is constructed from learned $d$-dimensional embeddings for every individual fashion article in the vocabulary $V$. We stack $L$ BERT blocks with multi-head attention (eight heads). For a masked-out item, the model predicts a probability distribution over all $|V|$ articles.
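The construction of the context vector from per-feature embeddings can be sketched as below. The feature names, their cardinalities, and the embedding sizes are hypothetical placeholders, not the production configuration. Numerical features such as age are assumed to be bucketized into categories first. Multi-valued features (e.g., several preferred brands) are left out of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical categorical features: (number of categories, embedding size).
FEATURES = {"age_bucket": (10, 4), "country": (20, 4), "occasion": (8, 3)}
# One embedding table per feature, randomly initialized (trained in practice).
tables = {name: rng.normal(size=(k, dim)) for name, (k, dim) in FEATURES.items()}

def context_vector(customer):
    """Look up each feature's embedding and concatenate them into the context vector."""
    return np.concatenate([tables[name][customer[name]] for name in FEATURES])

c = context_vector({"age_bucket": 3, "country": 7, "occasion": 1})
print(c.shape)  # (11,)
```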
Although our data is not an NLP dataset, it shares many of the important traits of a textual corpus. The vocabulary size is comparable to that of the word-piece vocabularies commonly used with BERT models. Fashion outfits are similar to sentences in that some articles often appear together (they match style-wise) and others do not. The typical sequence length, however, differs: outfits contain four to eight fashion articles, with an average length of exactly five. Moreover, in contrast to sentences, outfits do not have an inherent order. To account for this, we remove the positional encoding from BERT so that it treats its input as a set.
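Removing the positional encoding makes BERT treat its input as a set, and this can be checked numerically. Without positional information, self-attention is permutation-equivariant: reordering the articles of an outfit only reorders the outputs, so the prediction at a masked position does not depend on the order of the other items. The sketch uses one random single-head attention layer as a stand-in for a full BERT block. The random "positional encodings" in the second check are an assumption used only to show that the symmetry breaks once position information is added.

```python
import numpy as np

def self_attention(H, W_Q, W_K, W_V):
    """Single-head scaled dot-product self-attention."""
    S = (H @ W_Q) @ (H @ W_K).T / np.sqrt(W_K.shape[1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ (H @ W_V)

rng = np.random.default_rng(0)
n, d = 5, 8                                   # an outfit of five articles
H = rng.normal(size=(n, d))                   # article embeddings only, no positions
W_Q, W_K, W_V = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))
perm = np.array([4, 2, 0, 3, 1])              # reorder the articles of the outfit

out = self_attention(H, W_Q, W_K, W_V)
out_perm = self_attention(H[perm], W_Q, W_K, W_V)
print(np.allclose(out_perm, out[perm]))       # True: outputs are just reordered

# Adding (here: random) positional encodings breaks this symmetry.
PE = rng.normal(size=(n, d))
out_pe = self_attention(H + PE, W_Q, W_K, W_V)
out_pe_perm = self_attention(H[perm] + PE, W_Q, W_K, W_V)
print(np.allclose(out_pe_perm, out_pe[perm]))
```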
Table 1 shows the results of evaluating the four different methods. We compare cross-entropy and recall@rank $k$ (r@$k$ for short) on a randomly selected validation dataset of 17k outfits that are held out during training. The r@$k$ is defined as the percentage of cases in which the masked-out item is among the $k$ most probable items according to the model. Parameter counts exclude the embedding spaces for customer features and the last dense layer.
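For reference, both evaluation metrics can be computed from the model's predicted distributions over articles as sketched below. The function names and the three-case toy data are made up for illustration. Ties in predicted probability are broken arbitrarily by `np.argsort`, a detail the text does not specify.

```python
import numpy as np

def recall_at_k(probs, targets, k):
    """Percentage of cases whose masked-out item is among the k most probable items.

    probs:   (N, num_articles) predicted distributions for the masked positions
    targets: (N,) index of the article that was actually masked out
    """
    top_k = np.argsort(-probs, axis=1)[:, :k]         # k highest-probability articles per case
    hits = (top_k == targets[:, None]).any(axis=1)
    return 100.0 * hits.mean()

def cross_entropy(probs, targets):
    """Mean negative log-probability assigned to the true article."""
    return float(-np.mean(np.log(probs[np.arange(len(targets)), targets])))

probs = np.array([[0.1, 0.6, 0.3],
                  [0.5, 0.2, 0.3],
                  [0.2, 0.3, 0.5]])
targets = np.array([2, 1, 2])
print(recall_at_k(probs, targets, 1))   # one hit out of three cases
print(recall_at_k(probs, targets, 2))   # two hits out of three cases
print(cross_entropy(probs, targets))
```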
The empirical evaluation reveals the effectiveness of using a context for making predictions. The model's ability to replicate the stylist behavior, i.e., to achieve a higher r@$k$, improves substantially with the addition of a context. On r@… we see a relative improvement of … by using the [GSU] method for conditioning over using no customer context at all ([None]), and of … compared to [NP] (the best method without a global state).
A comparison of the four different conditioning methods shows [GSU] to be most effective, followed by [GS], [NP], and [C]. The methods [C] and [NP] do not have any bias towards treating the context vector specially. They attend to other positions in the sequence the same way they attend to the context. The superiority of [GS] and [GSU] can presumably be explained by their explicit architectural ability to retrieve information from the global state and therefore effectively utilize the context for their prediction.
We acknowledge the differences between our outfits dataset and typical NLP benchmarks. Nonetheless, we hypothesize that the effectiveness of our methods translates to NLP, in particular to use-cases in which the modalities of context and sequence differ, e.g., contexts comprising numerical or categorical metadata about the text. The reason is that the model's freedom to read from the context separately allows it to process the different modalities of context and input sequence adequately.
5 Conclusions and Future Work
With Contextual BERT, we presented novel ways of conditioning the BERT model. The strong performance on a real-world use-case provides evidence for the superiority of using a global state to inject context into the Transformer-based architecture. Our proposal enables the effective conditioning of BERT, potentially leading to improvements in a range of applications where contextual information is relevant.
A promising idea for follow-up work is to allow for information to flow from the sequence to the global state. Further, it would be desirable to establish a contextual NLP benchmark for the research community to compete on. This benchmark would task competitors with contextualized NLP problems, e.g., social media platform-dependent text generation or named entity recognition for multiple domains.
- For details on the GNN definition of a global state, we refer the reader to Section 3.2 in Battaglia et al. (2018).
- Note that the global state sequence does not admit attention in the sense of computing a weighted average of multiple vectors, because it consists of a single vector only. Instead, Equation 1 reduces to letting the $l$th layer transform the context and update its internal state based on it.
- P. W. Battaglia et al. (2018) Relational inductive biases, deep learning, and graph networks. CoRR abs/1806.01261.
- J. Burstein, C. Doran and T. Solorio (Eds.) (2019) Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers). Association for Computational Linguistics.
- K. Clark, U. Khandelwal, O. Levy and C. D. Manning (2019) What does BERT look at? An analysis of BERT's attention. CoRR abs/1906.04341.
- J. Devlin, M. Chang, K. Lee and K. Toutanova (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT 2019 (Burstein et al., 2019).
- A. Radford, J. Wu, R. Child, D. Luan, D. Amodei and I. Sutskever (2019) Language models are unsupervised multitask learners. OpenAI Blog 1 (8), pp. 9.
- A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, 4-9 December 2017, Long Beach, CA, USA, I. Guyon, U. von Luxburg, S. Bengio, H. M. Wallach, R. Fergus, S. V. N. Vishwanathan and R. Garnett (Eds.), pp. 5998–6008.