Sequence Tagging with Policy-Value Networks and Tree Search
In this paper we propose a novel reinforcement learning based model for sequence tagging, referred to as MM-Tag. Inspired by the success and methodology of AlphaGo Zero, MM-Tag formalizes the problem of sequence tagging with a Monte Carlo tree search (MCTS) enhanced Markov decision process (MDP) model, in which the time steps correspond to the positions of words in a sentence from left to right, and each action corresponds to assigning a tag to a word. Two long short-term memory (LSTM) networks are used to summarize the past tag assignments and the words in the sentence. Based on the outputs of the LSTMs, a policy for guiding the tag assignment and a value for predicting the tagging accuracy of the whole sentence are produced. The policy and value are then strengthened with MCTS, which takes the raw policy and value as inputs, simulates and evaluates the possible tag assignments at the subsequent positions, and outputs a better search policy for assigning tags. A reinforcement learning algorithm is proposed to train the model parameters. Our work is the first to apply the MCTS enhanced MDP model to the sequence tagging task. We show that MM-Tag can accurately predict the tags thanks to the exploratory decision making mechanism introduced by MCTS. Experimental results on a chunking benchmark show that MM-Tag outperforms state-of-the-art sequence tagging baselines, including CRF and CRF combined with LSTM.
Sequence tagging, including part-of-speech (POS) tagging, chunking, and named entity recognition, has received considerable research attention for decades. Taking chunking as an example: given a sequence of text (e.g., a sentence), each word in the sequence receives a "tag" (class label) that expresses its phrase type.
Existing sequence tagging models can be categorized into statistical models and deep neural network based models. Traditional research on sequence tagging focused on linear statistical models, including the maximum entropy (ME) classifier (Ratnaparkhi, 1996) and maximum entropy Markov models (MEMMs) (McCallum et al., 2000). These models predict a distribution of tags for each time step and then use beam-like decoding to find optimal tag sequences. Lafferty et al. proposed conditional random fields (CRF) (Lafferty et al., 2001) to leverage global sentence-level features and solve the label bias problem in MEMMs. All of the linear statistical models rely heavily on hand-crafted features, e.g., word spelling features for the task of part-of-speech tagging. Motivated by the success of deep learning, deep neural network based models have been proposed for sequence tagging in recent years. Most of them directly combine deep neural networks with CRF. For example, Huang et al. (2015) used a bidirectional LSTM to automatically extract word-level representations and then combined it with CRF for joint label decoding. Ma and Hovy (2016) introduced a neural network architecture in which both word-level and character-level features are used, combining bidirectional LSTM, CNN, and CRF. Reinforcement learning has also been applied to the task. For example, Maes et al. formalized the sequence tagging task as a Markov decision process (MDP) and used the reinforcement learning algorithm SARSA to construct the optimal sequence directly in a greedy manner (Maes et al., 2007). Feng et al. proposed a novel model to address the noise problem of the relation classification task caused by distant supervision, in which an instance selector trained with the REINFORCE algorithm assigns a select or delete action (label) to every sentence (Feng et al., 2018).
Inspired by the reinforcement learning models of the AlphaGo (Silver et al., 2016) and AlphaGo Zero (Silver et al., 2017) programs designed for the game of Go, in this paper we propose to solve sequence tagging with a Monte Carlo tree search (MCTS) enhanced Markov decision process (MDP). The new sequence tagging model, referred to as MM-Tag (MCTS enhanced MDP for Tagging), makes use of an MDP to model the sequential tag assignment process in sequence tagging. At each position, two long short-term memory networks (LSTM) are used to summarize the past words and tags, respectively. Based on the outputs of these two LSTMs, a policy function (a distribution over the valid tags) for guiding the tag assignment and a value function for estimating the accuracy of tagging are produced. To avoid assigning tags without utilizing sentence-level tag information, instead of choosing a tag directly with the raw policy predicted by the policy function, MM-Tag explores more possibilities in the whole space. The exploration is conducted with MCTS guided by the produced policy function and value function, resulting in a strengthened search policy for the tag assignment. The algorithm then moves to the next position and continues the above process until the end of the sentence is reached.
Reinforcement learning is used to train the model parameters. In the training phase, at each learning iteration and for each training sentence (and the corresponding labels), the algorithm first conducts an MCTS inside the training loop, guided by the current policy function and value function. Then the model parameters are adjusted to minimize a loss function consisting of two terms: 1) the squared error between the predicted value and the final ground truth accuracy of the whole sentence tagging; and 2) the cross entropy between the predicted policy and the search probabilities for tag selection. Stochastic gradient descent is utilized for conducting the optimization.
To evaluate the effectiveness of MM-Tag, we conducted experiments on the CoNLL 2000 chunking dataset. The experimental results showed that MM-Tag can significantly outperform the state-of-the-art sequence tagging approaches, including the linear statistical model CRF and the neural network based model BI-LSTM-CRF. We analyzed the results and showed that MM-Tag improves the performance by conducting lookahead MCTS to explore the whole tagging space.
2. MDP formulation of sequence tagging
In this section, we introduce the proposed MM-Tag model.
2.1. Sequence tagging as an MDP
Suppose that $X = x_1 x_2 \cdots x_L$ is a sequence of words (a sentence) to be labeled, and $Y = y_1 y_2 \cdots y_L$ is the corresponding ground truth tag sequence. All components of $X$ are the $d$-dimensional preliminary representations of the words, i.e., the word embeddings. All components of $Y$ are assumed to be selected from a finite tag set $\mathcal{Y}$. For example, $\mathcal{Y}$ may be the set of possible part-of-speech tags. The goal of sequence tagging is to construct a model that can automatically assign a tag $y_i \in \mathcal{Y}$ to each word $x_i$ in the input sentence $X$.
MM-Tag formulates the assignment of tags to sentences as a process of sequential decision making with an MDP in which each time step corresponds to a position in the sentence. The states, actions, transition function, value function, and policy function of the MDP are set as:
States $S$: We design the state at time step $t$ as a pair $s_t = \langle [X_{1:t}, X_{t+1:L}], Y_{1:t-1} \rangle$, where $X_{1:t} = x_1 \cdots x_t$ and $X_{t+1:L} = x_{t+1} \cdots x_L$ are the preliminary representations of the left context and right context of the input sentence of length $L$ at time step $t$, and $Y_{1:t-1} = y_1 \cdots y_{t-1}$ is the prefix of the label sequence, of length $t-1$. At the beginning ($t = 1$), the state is initialized as $s_1 = \langle [x_1, X_{2:L}], \emptyset \rangle$, where $\emptyset$ is the empty sequence.
Actions $A$: At each time step $t$, $A(s_t) = \mathcal{Y}$ is the set of actions the agent can choose from. That is, the action $a_t \in \mathcal{Y}$ actually chooses a tag for the word $x_t$.
Transition function $T$: $T(s_t, a_t)$ is defined as
$$s_{t+1} = T(s_t, a_t) = \langle [X_{1:t} \oplus x_{t+1},\; X_{t+1:L} \ominus x_{t+1}],\; Y_{1:t-1} \oplus a_t \rangle,$$
where $\oplus$ appends $x_{t+1}$ and $a_t$ to $X_{1:t}$ and $Y_{1:t-1}$, respectively, and $\ominus$ deletes $x_{t+1}$ from $X_{t+1:L}$. At each time step $t$, based on the state $s_t$, the system chooses an action (tag) $a_t$ for the word at position $t$. Then, the system moves to time step $t+1$ and transits to a new state $s_{t+1}$: first, the left context is updated by concatenating the next word $x_{t+1}$, while the right context is updated by deleting $x_{t+1}$; second, the system appends the selected tag $a_t$ to the end of $Y_{1:t-1}$, generating a new tag sequence $Y_{1:t}$.
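The state and transition above can be sketched as follows; this is a minimal illustration in which words are kept as plain strings rather than embeddings, and the names `State`, `initial_state`, and `transition` are ours, not the paper's.

```python
from dataclasses import dataclass

# A minimal sketch of the MDP state and transition described above.
# Words are plain strings here instead of embeddings.

@dataclass(frozen=True)
class State:
    left: tuple   # left context x_1 .. x_t (words reached so far)
    right: tuple  # right context x_{t+1} .. x_L (words not yet reached)
    tags: tuple   # tag prefix y_1 .. y_{t-1}

def initial_state(sentence):
    # s_1 = <[x_1, X_{2:L}], empty tag sequence>
    return State(left=(sentence[0],), right=tuple(sentence[1:]), tags=())

def transition(state, action):
    # Append the chosen tag; shift the next word from the right context
    # to the left context (a no-op at the last position).
    return State(left=state.left + state.right[:1],
                 right=state.right[1:],
                 tags=state.tags + (action,))

s = initial_state(["He", "reckons", "the", "deficit"])
s = transition(s, "B-NP")   # tag the first word
```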
Value function $V$: The state value function $V(s)$ is a scalar evaluation predicting the accuracy of the tag assignments for the whole sentence (an episode), on the basis of the input state. The value function is learned so as to fit the real tag assignment accuracies of the training sentences.
In this paper, we use three LSTMs to respectively map the left context $X_{1:t}$, right context $X_{t+1:L}$, and tag sequence $Y_{1:t-1}$ in the state to real vectors, and then define the value function as a nonlinear transformation of the weighted sum of the LSTMs' outputs:
$$V(s_t) = \sigma\left( w^\top v(s_t) + b \right),$$
where $w$ and $b$ are the weight vector and the bias to be learned during training, $\sigma$ is the nonlinear sigmoid function, and $v(s_t)$ is a concatenation of the outputs from the word LSTMs $\mathrm{LSTM}_l$ and $\mathrm{LSTM}_r$ and the tag LSTM $\mathrm{LSTM}_y$:
$$v(s_t) = \left[ \mathrm{LSTM}_l(X_{1:t});\; \mathrm{LSTM}_r(X_{t+1:L});\; \mathrm{LSTM}_y(Y_{1:t-1}) \right].$$
The three LSTM networks are defined as follows: given $X_{1:t} = x_1 \cdots x_t$, where $x_i$ is the word at the $i$-th position, represented with its word embedding, and $y_i$ is the label at the $i$-th position, represented with a one-hot vector, $\mathrm{LSTM}_l$ outputs a representation $h_i$ for position $i$:
$$\begin{aligned} f_i &= \sigma(W_f [h_{i-1}; x_i] + b_f),\\ \iota_i &= \sigma(W_\iota [h_{i-1}; x_i] + b_\iota),\\ o_i &= \sigma(W_o [h_{i-1}; x_i] + b_o),\\ c_i &= f_i \odot c_{i-1} + \iota_i \odot \tanh(W_c [h_{i-1}; x_i] + b_c),\\ h_i &= o_i \odot \tanh(c_i), \end{aligned}$$
where $h_0$ and $c_0$ are initialized with zero vectors; the operator "$\odot$" denotes the element-wise product and "$\sigma$" is applied to each of the entries; the variables $f_i$, $\iota_i$, $o_i$, $c_i$, and $h_i$ denote the forget gate's activation vector, input gate's activation vector, output gate's activation vector, cell state vector, and output vector of the LSTM block, respectively. $W_f$, $W_\iota$, $W_o$, $W_c$ and $b_f$, $b_\iota$, $b_o$, $b_c$ are the weight matrices and bias vectors that need to be learned during training. The output vector and cell state vector at the $t$-th cell are concatenated as the output of the LSTM, that is, $\mathrm{LSTM}_l(X_{1:t}) = [h_t; c_t]$.
The function $\mathrm{LSTM}_y$, which is used to map the tag sequence $Y_{1:t-1}$ into a real vector, is defined similarly to $\mathrm{LSTM}_l$, with the one-hot tag vectors as inputs; the same holds for $\mathrm{LSTM}_r$ over the right context.
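One step of the LSTM block above can be sketched with numpy; the weight shapes and the dictionary-of-gates layout are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

# A numpy sketch of one LSTM step following the equations above; the
# input to every gate is the stacked vector [h_{i-1}; x_i].

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    z = np.concatenate([h_prev, x])
    f = sigmoid(W["f"] @ z + b["f"])                    # forget gate
    i = sigmoid(W["i"] @ z + b["i"])                    # input gate
    o = sigmoid(W["o"] @ z + b["o"])                    # output gate
    c = f * c_prev + i * np.tanh(W["c"] @ z + b["c"])   # cell state
    h = o * np.tanh(c)                                  # output vector
    return h, c

rng = np.random.default_rng(0)
d, k = 4, 3   # embedding size, number of hidden units
W = {g: rng.normal(size=(k, k + d)) for g in "fioc"}
b = {g: np.zeros(k) for g in "fioc"}
h, c = lstm_step(rng.normal(size=d), np.zeros(k), np.zeros(k), W, b)
rep = np.concatenate([h, c])   # LSTM output at this position: [h_t; c_t]
```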
Policy function $p$: The policy $p(a|s)$ defines a function that takes the state $s$ as input and outputs a distribution over all of the possible actions $a \in \mathcal{Y}$. Specifically, each probability in the distribution is a normalized soft-max function whose input is the bilinear product of the state representation $v(s)$ in Equation (3) and the selected tag:
$$p(a|s) = \frac{\exp\left( v(s)^\top W_p\, e_a \right)}{\sum_{a' \in \mathcal{Y}} \exp\left( v(s)^\top W_p\, e_{a'} \right)},$$
where $e_a$ is the one-hot vector representing the tag $a$ and $W_p$ is the parameter matrix of the bilinear product. The policy function is $p(\cdot|s) = \left[ p(a|s) \right]_{a \in \mathcal{Y}}$.
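The bilinear soft-max policy can be sketched as follows. Because $e_a$ is one-hot, the bilinear product $v(s)^\top W_p\, e_a$ simply selects the $a$-th column of $W_p$, so the scores for all tags are computed at once as $v(s)^\top W_p$; the shapes below are illustrative.

```python
import numpy as np

# A sketch of the bilinear soft-max policy p(a|s).

def policy(v_s, W_p):
    scores = v_s @ W_p              # one bilinear score per tag
    scores = scores - scores.max()  # stabilize the soft-max
    e = np.exp(scores)
    return e / e.sum()

v_s = np.array([0.2, -0.1, 0.5])    # state representation v(s)
W_p = np.zeros((3, 4))              # 3-dim state, 4 candidate tags
p = policy(v_s, W_p)                # uniform when all scores are equal
```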
2.2. Strengthening raw policy with MCTS
Tagging directly with the predicted raw policy in Equation (4) may lead to suboptimal results because the policy is calculated based only on the past tags. The raw policy has no knowledge of the tags that will be assigned to the future words. To alleviate the problem, following the practices in AlphaGo (Silver et al., 2016) and AlphaGo Zero (Silver et al., 2017), we propose to conduct lookahead search with MCTS. That is, at each position $t$, an MCTS search is executed, guided by the policy function $p$ and the value function $V$, and outputs a strengthened new search policy $\pi_t$. Usually, the search policy has a higher probability of selecting an accurate tag than the raw policy defined in Equation (4).
Algorithm 1 shows the details of the MCTS, in which each tree node corresponds to an MDP state. It takes a root node $s_R$, the value function $V$ and the policy function $p$ as inputs. The algorithm iterates $K$ times and outputs a strengthened search policy $\pi$ for selecting a tag at the root node $s_R$. Suppose that each edge $(s, a)$ (the edge from the state $s$ to the state $T(s, a)$) of the MCTS tree stores an action value $Q(s, a)$, a visit count $N(s, a)$, and a prior probability $p(a|s)$. At each iteration, the MCTS executes the following steps:
Selection: Each iteration starts from the root state $s_R$ and iteratively selects the tags that maximize an upper confidence bound:
$$a = \arg\max_{a' \in \mathcal{Y}} \left( Q(s, a') + u(s, a') \right),$$
where the bonus is $u(s, a) = c \cdot p(a|s) \frac{\sqrt{\sum_{a'} N(s, a')}}{1 + N(s, a)}$ and $c$ is the trade-off coefficient. The bonus $u(s, a)$ is proportional to the prior probability but decays with repeated visits to encourage exploration.
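The selection rule can be sketched as an argmax over $Q(s,a) + u(s,a)$ computed per tag; the array layout (one entry per tag) is an illustrative assumption.

```python
import numpy as np

# A sketch of upper-confidence selection at one tree node:
# u(s,a) = c * p(a|s) * sqrt(sum_a' N(s,a')) / (1 + N(s,a)).

def select_action(Q, N, prior, c):
    u = c * prior * np.sqrt(N.sum()) / (1.0 + N)
    return int(np.argmax(Q + u))

Q = np.array([0.5, 0.5, 0.5])       # equal action values
N = np.array([10, 1, 0])            # visit counts so far
prior = np.array([0.2, 0.3, 0.5])   # p(a|s) from the policy network
a = select_action(Q, N, prior, c=1.0)
# With equal Q, the unvisited high-prior tag (index 2) is selected.
```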
Evaluation and expansion: When the traversal reaches a leaf node $s_l$, the node is evaluated with the value function $V(s_l)$ (Equation (2)). Note that, following the practices in AlphaGo Zero, we use the value function instead of rollouts for evaluating a node.
Then, the leaf node may be expanded. Each edge from the leaf position $s_l$ (corresponding to an action $a \in \mathcal{Y}$) is initialized as: $P(s_l, a) = p(a|s_l)$ (Equation (4)), $Q(s_l, a) = 0$, and $N(s_l, a) = 0$. In this paper, all of the available actions of $s_l$ are expanded.
Back-propagation and update: At the end of the evaluation, the action values and visit counts of all traversed edges are updated. For each edge $(s, a)$, the prior probability is kept unchanged, while $N(s, a)$ and $Q(s, a)$ are updated:
$$N(s, a) \leftarrow N(s, a) + 1, \qquad Q(s, a) \leftarrow Q(s, a) + \frac{V(s_l) - Q(s, a)}{N(s, a)}.$$
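The update applied to each traversed edge is an incremental mean, which can be sketched in a few lines; the function name `backup` is ours.

```python
# A sketch of the back-propagation step for one edge:
# N <- N + 1, then Q <- Q + (V(s_l) - Q) / N, so Q stays the running
# mean of the leaf values of all simulations that passed through it.

def backup(Q, N, leaf_value):
    N += 1
    Q += (leaf_value - Q) / N
    return Q, N

Q, N = 0.0, 0
Q, N = backup(Q, N, 1.0)   # first simulation through this edge
Q, N = backup(Q, N, 0.0)   # second simulation
```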
Calculating the strengthened search policy: Finally, after iterating $K$ times, the strengthened search policy $\pi$ for the root node $s_R$ can be calculated according to the visit counts of the edges starting from $s_R$:
$$\pi(a|s_R) = \frac{N(s_R, a)}{\sum_{a'} N(s_R, a')}$$
for all $a \in \mathcal{Y}$.
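The strengthened search policy is simply the normalized visit counts of the root's edges, sketched below with illustrative counts.

```python
import numpy as np

# pi(a | s_R) = N(s_R, a) / sum_a' N(s_R, a')

def search_policy(visit_counts):
    N = np.asarray(visit_counts, dtype=float)
    return N / N.sum()

pi = search_policy([40, 150, 10])   # counts after 200 simulations
```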
2.3. Learning and inference algorithms
2.3.1. Reinforcement learning of the parameters
The model has parameters $\Theta$ (including $w$, $b$, $W_p$, and the parameters in $\mathrm{LSTM}_l$, $\mathrm{LSTM}_r$, and $\mathrm{LSTM}_y$) to learn. In the training phase, suppose we are given a set of labeled sentences $\{(X, Y)\}$. Algorithm 2 shows the training procedure. First, the parameters $\Theta$ are initialized to random weights. At each subsequent iteration, for each tagged sentence $(X, Y)$, a tag sequence $\hat{Y}$ is predicted for $X$ with the current parameter setting: at each position $t$, an MCTS search is executed, using the previous iteration's value function and policy function, and a tag is selected according to the search policy $\pi_t$. The tagging terminates at the end of the sentence, yielding a predicted tag sequence $\hat{Y}$. Given the ground truth tag sequence $Y$, the overall prediction accuracy of the sentence, denoted as $r$, is calculated. The data generated at each time step, $(s_t, \pi_t)$, and the final evaluation $r$ are utilized as the signals in training for adjusting the value function. The model parameters are adjusted to minimize the error between the predicted value $V(s_t)$ and the whole-sentence accuracy $r$, and to maximize the similarity of the policy $p(\cdot|s_t)$ to the search probabilities $\pi_t$. Specifically, the parameters are adjusted by gradient descent on a loss function that sums over the mean squared error and cross entropy losses, respectively:
$$\ell(\Theta) = \sum_{t=1}^{L} \left( \left( V(s_t) - r \right)^2 - \pi_t^\top \log p(\cdot|s_t) \right).$$
The model parameters are trained by back propagation and stochastic gradient descent. Specifically, we use AdaGrad (Duchi et al., 2011) on all parameters in the training process.
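The per-sentence loss described above can be sketched as follows: a squared-error term per step plus a cross-entropy term between the search policy and the raw policy. The flat-list inputs and the `eps` guard on the log are illustrative assumptions.

```python
import numpy as np

# A sketch of the training loss, summed over the positions of one
# sentence: (V(s_t) - r)^2 - pi_t . log p(.|s_t).

def mm_tag_loss(values, r, search_policies, raw_policies, eps=1e-12):
    total = 0.0
    for v_t, pi_t, p_t in zip(values, search_policies, raw_policies):
        total += (v_t - r) ** 2                               # squared error
        total -= np.dot(pi_t, np.log(np.asarray(p_t) + eps))  # cross entropy
    return total

pi = [np.array([1.0, 0.0])]         # search policy at the single step
p = [np.array([0.5, 0.5])]          # raw predicted policy
l = mm_tag_loss([0.8], 1.0, pi, p)  # (0.8 - 1)^2 + log 2
```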
The inference of the tag sequence for a sentence is shown in Algorithm 3. Given a sentence $X$, the system state is initialized as $s_1 = \langle [x_1, X_{2:L}], \emptyset \rangle$. Then, at each of the time steps $t = 1, \ldots, L$, the agent receives the state $s_t$ and computes the search policy $\pi_t$ with MCTS, on the basis of the value function $V$ and the policy function $p$. Then, it chooses an action $a_t = \arg\max_a \pi_t(a)$ for the word at position $t$. Moving to the next iteration $t+1$, the state becomes $s_{t+1} = T(s_t, a_t)$. The process is repeated until the end of the sentence is reached.
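The inference loop can be sketched as below. Here `mcts` is a stand-in for the tree search of Algorithm 1; the uniform dummy policy and the `infer` name are for illustration only.

```python
import numpy as np

# A self-contained sketch of the inference loop: at each position, run
# the tree search to get pi_t, take the most probable tag, and advance.

def infer(sentence, tag_set, mcts):
    tags = []
    for t in range(len(sentence)):
        pi_t = mcts(sentence, tuple(tags))    # strengthened search policy
        tags.append(tag_set[int(np.argmax(pi_t))])
    return tags

# Dummy search returning a uniform distribution over two tags.
uniform_mcts = lambda sentence, prefix: np.ones(2) / 2.0
result = infer(["He", "runs"], ["B-NP", "B-VP"], uniform_mcts)
```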
We implemented the MM-Tag model based on TensorFlow and the code can be found at the Github repository http://hide_for_anonymous_review.
We tested the performance of MM-Tag on subsets of the CoNLL 2000 chunking set (https://www.clips.uantwerpen.be/conll2000/chunking/). In the chunking task, each word is tagged with its phrase type; e.g., the tag "B-NP" indicates a word starting a noun phrase. Considering that MM-Tag is time-consuming when parsing long sentences, we constructed a short-sentence subset, randomly selected from the whole CoNLL 2000 chunking set, from which the sentences longer than 13 words were removed. The final short-sentence subset consists of 1000 sentences, and the average sentence length is 9 words. Among them, 900 were used for training and 100 for testing. All of the words in the sentences were represented with word embeddings. In the experiments, we used the publicly available GloVe 100-dimensional embeddings trained on 6 billion words from Wikipedia and Gigaword (Pennington et al., 2014).
We compared MM-Tag with the linear statistical model CRF, implemented with the open-source software CRFsuite (http://www.chokkan.org/software/crfsuite/) (Okazaki, 2007), and the neural models LSTM-CRF and BI-LSTM-CRF, following the configurations in (Huang et al., 2015). For CRF, 935 spelling and context features were extracted, including word identity, word suffix, word shape, and the word POS tags of the current and nearby words, etc.
For MM-Tag, four hyper-parameters were set: the number of MCTS search iterations $K$, the learning rate, the tree search trade-off coefficient $c$, and the number of hidden units in the LSTMs.
Table 2 reports the performances of MM-Tag and the baseline methods in terms of tagging precision, recall, F1, and accuracy. Boldface indicates the highest scores among all runs. From the results we can see that MM-Tag outperformed all of the baseline methods in terms of all of the evaluation metrics, indicating the effectiveness of the proposed MM-Tag model. We note that the neural methods (LSTM-CRF and BI-LSTM-CRF) underperformed CRF and MM-Tag. The reason may be that the short-sentence subset is not sufficient for learning the large number of parameters in the neural networks.
In this paper we have proposed a novel approach to sequence tagging, referred to as MM-Tag. MM-Tag formalizes the tagging of a sentence as sequential decision-making with an MDP. Lookahead MCTS is used to strengthen the raw predicted policy so that the search policy has a high probability of selecting the correct tag for each word. Reinforcement learning is utilized to train the model parameters. MM-Tag enjoys several advantages: tagging with shared policy and value functions, end-to-end learning, and high accuracy in tagging. Experimental results show that MM-Tag outperformed the baselines of CRF, LSTM-CRF, and BI-LSTM-CRF.
- Duchi et al. (2011) John Duchi, Elad Hazan, and Yoram Singer. 2011. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research 12, Jul (2011), 2121–2159.
- Feng et al. (2018) Jun Feng, Minlie Huang, Li Zhao, Yang Yang, and Xiaoyan Zhu. 2018. Reinforcement Learning for Relation Classification from Noisy Data. (2018).
- Huang et al. (2015) Zhiheng Huang, Wei Xu, and Kai Yu. 2015. Bidirectional LSTM-CRF models for sequence tagging. arXiv preprint arXiv:1508.01991 (2015).
- Lafferty et al. (2001) John Lafferty, Andrew McCallum, and Fernando CN Pereira. 2001. Conditional random fields: Probabilistic models for segmenting and labeling sequence data. (2001).
- Ma and Hovy (2016) Xuezhe Ma and Eduard Hovy. 2016. End-to-end sequence labeling via bi-directional lstm-cnns-crf. arXiv preprint arXiv:1603.01354 (2016).
- Maes et al. (2007) Francis Maes, Ludovic Denoyer, and Patrick Gallinari. 2007. Sequence labeling with reinforcement learning and ranking algorithms. In European Conference on Machine Learning. Springer, 648–657.
- McCallum et al. (2000) Andrew McCallum, Dayne Freitag, and Fernando CN Pereira. 2000. Maximum Entropy Markov Models for Information Extraction and Segmentation.. In Icml, Vol. 17. 591–598.
- Okazaki (2007) Naoaki Okazaki. 2007. CRFsuite: a fast implementation of Conditional Random Fields. (2007).
- Pennington et al. (2014) Jeffrey Pennington, Richard Socher, and Christopher Manning. 2014. GloVe: Global vectors for word representation. In EMNLP 2014. 1532–1543.
- Ratnaparkhi (1996) Adwait Ratnaparkhi. 1996. A maximum entropy model for part-of-speech tagging. In Conference on Empirical Methods in Natural Language Processing.
- Silver et al. (2016) David Silver, Aja Huang, Chris J Maddison, Arthur Guez, Laurent Sifre, George Van Den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, et al. 2016. Mastering the game of Go with deep neural networks and tree search. nature 529, 7587 (2016), 484–489.
- Silver et al. (2017) David Silver, Julian Schrittwieser, Karen Simonyan, Ioannis Antonoglou, Aja Huang, Arthur Guez, Thomas Hubert, Lucas Baker, Matthew Lai, Adrian Bolton, et al. 2017. Mastering the game of go without human knowledge. Nature 550, 7676 (2017), 354.