The Stanford Question Answering Dataset (SQuAD) contains 100,000+ question-answer pairs extracted from 500+ Wikipedia articles, where each answer is a substring of the given article. Given a question (Q) about a specific paragraph (the context, C), the goal is to identify the correct answer span (A). At the time of writing, both humans and state-of-the-art models can easily achieve 80% exact-match (EM) accuracy on this task.
A typical setup for SOTA models fuses a one-hot character embedding with a standard word embedding (e.g. GloVe), encodes both the question and the context with one or more recurrent layers, applies attention from the context to the question (and sometimes vice versa), applies another set of recurrent layers, and finally produces two probability distributions over the paragraph for the start and end locations, respectively.
In our model, we deliberately choose to operate at the byte level rather than use pre-trained word embeddings. Since this model will eventually be adapted to several different domains, out-of-vocabulary words are a significant concern, and retrofitting the word embeddings for each domain is unappealing.
Our code is available on GitHub, and pretrained models with the appropriate bootstrapping code are available under
For our experiments, we use the Kullback–Leibler (KL) divergence instead of cross-entropy. This makes learning smoother: we represent the start and end positions as (discretized) Gaussian distributions centred on the labelled positions, as opposed to the hard 1/0 one-hot labels typically used with categorical cross-entropy.
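As a minimal sketch of this loss, assuming PyTorch (the helper names `gaussian_target` and `kl_position_loss` and the default width `sigma=1.0` are illustrative, not taken from our code), each target is a discretized Gaussian over context positions and the model's position logits are compared against it with `F.kl_div`:

```python
import torch
import torch.nn.functional as F


def gaussian_target(position: int, length: int, sigma: float = 1.0) -> torch.Tensor:
    """Discretized Gaussian over `length` positions, centred on `position`, summing to 1."""
    idx = torch.arange(length, dtype=torch.float32)
    weights = torch.exp(-0.5 * ((idx - position) / sigma) ** 2)
    return weights / weights.sum()


def kl_position_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean over the batch of KL(target || softmax(logits)).

    logits, target: (batch, context_length)
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target;
    # "batchmean" sums over positions and divides by the batch size.
    return F.kl_div(log_probs, target, reduction="batchmean")


# usage: soft start labels for two questions over a 50-character context
target = torch.stack([gaussian_target(10, 50), gaussian_target(30, 50)])
loss = kl_position_loss(torch.randn(2, 50), target)
```

Note that `F.kl_div` takes log-probabilities as its first argument; passing raw probabilities there is an easy mistake to make.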
Furthermore, since each question comes with multiple correct answer spans, we only take into account the minimum loss among them. As a result, a batch of size 64 contains exactly 64 questions (rather than 64 question-answer pairs), and the loss for each question is its minimum loss across the possible answers.
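The minimum-over-answers reduction can be sketched as follows. This is a hypothetical helper, assuming the answers for each question are padded to a fixed `max_answers` with a boolean mask; padded slots are set to infinity so that they can never be selected as the minimum.

```python
import torch
import torch.nn.functional as F


def min_over_answers_loss(start_logits, end_logits, start_targets, end_targets, answer_mask):
    """Per-question minimum of (start KL + end KL) over its correct answers, averaged over the batch.

    start_logits, end_logits:   (batch, context_len)
    start_targets, end_targets: (batch, max_answers, context_len) position distributions
    answer_mask:                (batch, max_answers) bool, False for padding slots
    """
    log_p_start = F.log_softmax(start_logits, dim=-1).unsqueeze(1).expand_as(start_targets)
    log_p_end = F.log_softmax(end_logits, dim=-1).unsqueeze(1).expand_as(end_targets)
    # elementwise KL terms, summed over positions -> (batch, max_answers)
    kl_start = F.kl_div(log_p_start, start_targets, reduction="none").sum(-1)
    kl_end = F.kl_div(log_p_end, end_targets, reduction="none").sum(-1)
    per_answer = (kl_start + kl_end).masked_fill(~answer_mask, float("inf"))
    # one loss value per question, regardless of how many answers it has
    return per_answer.min(dim=1).values.mean()
```

Without the mask, an all-zero padding target would have a KL of zero and would always win the minimum.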
In all of our intermediate models, we use the Adam optimizer for faster convergence; in our final model, however, we use SGD with momentum, as it tends to produce better results on the development set.
We start by building a baseline model in which the question and context are represented by a one-hot ASCII encoding, the mixer components are composed of a stack of "valid" (unpadded) convolution layers, and the attention module is a 2-headed attention mechanism that is summed over the question.
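One plausible reading of "2-headed attention summed over the question" is sketched below: for every context position, each head computes a softmax over question positions and takes the weighted sum of projected question features. The class name, the 1×1 convolution projections and the scaled dot-product scoring are assumptions for illustration, not a description of our exact module.

```python
import math

import torch
from torch import nn


class QuestionAttention(nn.Module):
    """Hypothetical multi-head context-to-question attention (2 heads by default)."""

    def __init__(self, channels: int, heads: int = 2):
        super().__init__()
        assert channels % heads == 0
        self.heads = heads
        self.head_dim = channels // heads
        self.query = nn.Conv1d(channels, channels, kernel_size=1)  # projects the context
        self.key = nn.Conv1d(channels, channels, kernel_size=1)    # projects the question
        self.value = nn.Conv1d(channels, channels, kernel_size=1)  # projects the question

    def forward(self, context, question):
        # context: (batch, channels, Lc); question: (batch, channels, Lq)
        b, _, lc = context.shape
        lq = question.shape[-1]
        q = self.query(context).view(b, self.heads, self.head_dim, lc)
        k = self.key(question).view(b, self.heads, self.head_dim, lq)
        v = self.value(question).view(b, self.heads, self.head_dim, lq)
        # scores: (batch, heads, Lc, Lq)
        scores = torch.einsum("bhdc,bhdq->bhcq", q, k) / math.sqrt(self.head_dim)
        weights = scores.softmax(dim=-1)  # normalize over question positions
        # weighted sum over the question -> (batch, heads, head_dim, Lc)
        attended = torch.einsum("bhcq,bhdq->bhdc", weights, v)
        return attended.reshape(b, self.heads * self.head_dim, lc)
```

The output has one feature vector per context position, so it can be concatenated with the context encoding before the mixer.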
The code for our text encoding scheme is shown below. Note that this representation is very inefficient (many ASCII characters are non-printing and therefore unlikely to appear in the corpus), and it is also lossy, since it maps every character whose code point is 255 or higher onto the same last dimension.
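```python
import functools

import torch


@functools.lru_cache()
def byte_vector(text):
    """Transform a string into a one-hot embedding.

    Arguments:
        text (string): the string to encode

    Returns:
        Tensor: tensor of size (256, length)
    """
    vec = torch.zeros(256, len(text))
    for i, c in enumerate(map(ord, text)):
        # code points >= 255 all collapse onto the last dimension (lossy)
        vec[min(255, c), i] = 1.0
    # lru_cache returns the same tensor object for repeated strings,
    # so callers must not modify the result in place
    return vec
```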
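A lossless alternative, not the scheme used above, is to one-hot encode the UTF-8 bytes of the text rather than clipping code points at 255. The sketch below is hypothetical (the names `utf8_byte_vector` and `char_to_byte_offset` are ours, for illustration). Because non-ASCII characters expand to two to four byte columns, character-based answer offsets such as SQuAD's `answer_start` then have to be converted to byte offsets.

```python
import functools

import torch


@functools.lru_cache()
def utf8_byte_vector(text):
    """Hypothetical lossless variant: one-hot encode the UTF-8 bytes of `text`.

    Returns a tensor of size (256, number_of_bytes). As with byte_vector, the
    cached tensor is shared between calls and must not be modified in place.
    """
    data = text.encode("utf-8")
    vec = torch.zeros(256, len(data))
    # every byte value 0-255 gets its own row, so no information is lost
    vec[torch.tensor(list(data), dtype=torch.long), torch.arange(len(data))] = 1.0
    return vec


def char_to_byte_offset(text, char_offset):
    """Convert a character offset (e.g. a SQuAD answer_start) into a byte offset."""
    return len(text[:char_offset].encode("utf-8"))
```

For example, `"h\u00e9llo"` has 5 characters but 6 UTF-8 bytes, so character offset 2 corresponds to byte offset 3.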
An alternative implementation incorporates positional information by concatenating a set of sine waves at different frequencies, as suggested by Vaswani et al. (2017), but we observe no change in model performance and omit it here for simplicity.
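For reference, such sinusoidal position features can be sketched as follows and concatenated to a (256, length) byte encoding along the channel axis. The number of position channels (16) is an assumption, and the sines and cosines are stacked in two blocks rather than interleaved as in Vaswani et al., which differs only by a permutation of channels.

```python
import math

import torch


def sinusoid_features(length: int, channels: int = 16) -> torch.Tensor:
    """Sine/cosine position features in the style of Vaswani et al. (2017), shape (channels, length)."""
    assert channels % 2 == 0, "channels must be even"
    position = torch.arange(length, dtype=torch.float32).unsqueeze(0)  # (1, length)
    # frequencies decrease geometrically from 1 down toward 1/10000
    freqs = torch.exp(torch.arange(0, channels, 2, dtype=torch.float32) * (-math.log(10000.0) / channels))
    angles = freqs.unsqueeze(1) * position  # (channels // 2, length)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=0)


def with_positions(byte_vec: torch.Tensor, channels: int = 16) -> torch.Tensor:
    """Concatenate position features to a (256, length) one-hot encoding along the channel axis."""
    return torch.cat([byte_vec, sinusoid_features(byte_vec.shape[1], channels)], dim=0)
```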
For both the encoder and mixer modules, we use a stack of 5 convolution layers with instance normalization between consecutive layers. Each convolution layer has filter width 11, so (with stride 1) each layer widens the receptive field by 10 characters: the encoder module has a receptive field of 51 characters (1 + 5 × 10), and the final output a receptive field of 101 characters (1 + 10 × 10).
As shown above, this gives us an extremely disappointing EM accuracy of ~30% on the development set.
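The baseline encoder and mixer stacks can be sketched as below, under stated assumptions: stride-1 "valid" `Conv1d` layers of width 11, `InstanceNorm1d` followed by ReLU between layers, and a hidden size of 128 chosen for illustration. Each layer shortens the sequence by 10 positions, which also makes the receptive-field arithmetic (1 + layers × (width − 1)) easy to check.

```python
import torch
from torch import nn


def receptive_field(num_layers: int, kernel_size: int) -> int:
    """Receptive field of stacked stride-1, undilated 1-D convolutions."""
    return 1 + num_layers * (kernel_size - 1)


class ConvStack(nn.Module):
    """Hypothetical sketch of a 5-layer "valid" convolution stack with instance normalization."""

    def __init__(self, in_channels: int, hidden: int = 128, layers: int = 5, kernel_size: int = 11):
        super().__init__()
        blocks = []
        for i in range(layers):
            # padding=0 (the default) gives a "valid" convolution
            blocks.append(nn.Conv1d(in_channels if i == 0 else hidden, hidden, kernel_size))
            if i < layers - 1:
                blocks += [nn.InstanceNorm1d(hidden), nn.ReLU()]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        # (batch, in_channels, length) -> (batch, hidden, length - layers * (kernel_size - 1))
        return self.net(x)
```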
Next, we draw inspiration from DenseNet (Huang et al., 2016) and connect each convolution layer to every subsequent convolution layer. We expect this to let gradients flow more smoothly through the network, resulting in faster convergence.
For a fair comparison, we ensure our DenseNet-inspired model has the same number of parameters as the baseline.
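```python
# relu/normalization/dropout not shown; the convolutions must preserve the
# sequence length (e.g. padding=kernel_size // 2) so the outputs can be concatenated
h1 = self.conv1(text_vec)
h2 = self.conv2(torch.cat([text_vec, h1], dim=1))
h3 = self.conv3(torch.cat([text_vec, h2, h1], dim=1))
h4 = self.conv4(torch.cat([text_vec, h3, h2, h1], dim=1))
h5 = self.conv5(torch.cat([text_vec, h4, h3, h2, h1], dim=1))
# the block output concatenates every layer's features (newest first)
return torch.cat([h5, h4, h3, h2, h1], dim=1)
```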
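The fragment above is excerpted from a `forward` method. A self-contained sketch of the same connectivity pattern follows; the growth rate of 32 channels per layer, the length-preserving padding and the normalization/ReLU ordering are assumptions, and dropout is omitted.

```python
import torch
from torch import nn


class DenseConvStack(nn.Module):
    """Hypothetical densely connected 1-D convolution stack.

    Layer k sees the input plus the outputs of all earlier layers. Padding of
    kernel_size // 2 keeps every output the same length (odd kernel sizes only),
    which the channel-wise concatenation requires.
    """

    def __init__(self, in_channels: int, growth: int = 32, layers: int = 5, kernel_size: int = 11):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(layers):
            self.convs.append(nn.Conv1d(in_channels + i * growth, growth, kernel_size, padding=kernel_size // 2))
            self.norms.append(nn.InstanceNorm1d(growth))

    def forward(self, x):
        features = []  # h1, h2, ... in creation order
        for conv, norm in zip(self.convs, self.norms):
            # input first, then earlier outputs newest first, as in the fragment
            h = torch.relu(norm(conv(torch.cat([x] + features[::-1], dim=1))))
            features.append(h)
        # all layer outputs (newest first), without the raw input
        return torch.cat(features[::-1], dim=1)
```

With 5 layers and a growth rate of 32, the block returns 160 channels at the input's length.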
Visualizing the Loss Landscape of Neural Nets (Li et al., 2018) suggests that such skip connections produce a smoother loss landscape, and we find that introducing this connectivity pattern results in a ~15% increase in EM accuracy for the same number of model parameters.
This brings us up to 40% EM accuracy, which is still disappointing but a big step in the right direction.
Recurrent vs. convolutional layers
Historically, 1D convolutions have provided a compelling alternative to recurrent neural networks. Convolutional layers are fully parallelizable because they do not need to process the input sequentially, and there are a variety of mathematical tricks (e.g. Fourier transforms) for faster GPU training and inference.
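To illustrate the Fourier-transform trick, the sketch below computes a single-channel "valid" 1D convolution through the convolution theorem and checks it against `F.conv1d`. It is purely illustrative; the function name is ours, and practical libraries choose among several convolution algorithms.

```python
import torch
import torch.nn.functional as F


def fft_conv1d_valid(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Single-channel "valid" 1-D convolution computed via the FFT.

    Like F.conv1d this is cross-correlation: y[i] = sum_k x[i + k] * w[k].
    """
    length, k = x.shape[-1], w.shape[-1]
    n = length + k - 1  # zero-pad so the circular correlation does not wrap around
    spectrum = torch.fft.rfft(x, n=n) * torch.conj(torch.fft.rfft(w, n=n))
    return torch.fft.irfft(spectrum, n=n)[: length - k + 1]


x = torch.randn(64, dtype=torch.float64)
w = torch.randn(11, dtype=torch.float64)
direct = F.conv1d(x.view(1, 1, -1), w.view(1, 1, -1)).view(-1)
```

Multiplying by the conjugate spectrum of `w` turns the FFT's circular convolution into the cross-correlation that deep-learning "convolution" layers actually compute.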
Recent neural machine translation (NMT) literature has shown that recurrent layers can be augmented or even replaced by (self-)attention mechanisms. Since we operate at the byte level (as opposed to the word level), however, purely attention-based approaches such as those highlighted in Attention Is All You Need (Vaswani et al., 2017) are not feasible: the cost of self-attention grows quadratically with sequence length, and byte-level sequences are several times longer than their word-level counterparts.
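For comparison with the convolutional stacks, a recurrent mixer can be sketched as a bidirectional GRU that consumes and returns channels-first tensors, so it can stand in for a `Conv1d` stack without other changes. The GRU choice, hidden size and layer count are assumptions for illustration. Unlike a convolutional stack with a fixed receptive field, every output position can in principle depend on the whole sequence.

```python
import torch
from torch import nn


class RecurrentMixer(nn.Module):
    """Hypothetical recurrent mixer: a bidirectional GRU over channels-first input."""

    def __init__(self, in_channels: int, hidden: int = 64, layers: int = 2):
        super().__init__()
        self.rnn = nn.GRU(in_channels, hidden, num_layers=layers, batch_first=True, bidirectional=True)

    def forward(self, x):
        # x: (batch, channels, length); nn.GRU with batch_first wants (batch, length, channels)
        out, _ = self.rnn(x.transpose(1, 2))  # (batch, length, 2 * hidden)
        return out.transpose(1, 2)  # back to (batch, 2 * hidden, length)
```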
We experiment with both convolutional and recurrent layers in the mixer components of our model and observe a 5-10% increase in performance when using recurrent layers. The plot below shows the EM accuracy of our recurrent model on our dev set:
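A large part of this discrepancy can be attributed to answer spans located more than 100 characters away from the relevant part of the context, i.e. beyond the roughly 100-character receptive field of the convolutional model (e.g. "The island next to... (more than 100 characters)... is named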
Furthermore, we observe that the precision@N metrics¹ for N > 1 are significantly better with recurrent layers; if we add up the precision metrics from N=1 to N=5 (as shown in the stacked bar chart), we see an approximately 40% increase in performance from the convolutional to the recurrent model. In other words, for the recurrent model, the "true" answer is among the top 5 predicted answer spans around 90% of the time.
¹ Note that the p@N metrics indicate the percentage of questions for which the correct answer is ranked in the Nth position by the model.
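A hypothetical sketch of how p@N could be computed: candidate spans are ranked by `p_start[s] * p_end[e]` subject to `s <= e < s + max_len` (the span-length cap is an assumption), and each question contributes to the rank at which its first correct span appears. Summing p@1 through p@5 then gives the fraction of questions whose correct answer is in the top 5.

```python
import torch


def rank_spans(p_start, p_end, max_len, top_k=5):
    """Top-k (start, end) spans ranked by p_start[s] * p_end[e], with s <= e < s + max_len."""
    length = p_start.shape[0]
    scores = p_start.unsqueeze(1) * p_end.unsqueeze(0)  # scores[s, e]
    s_idx = torch.arange(length).unsqueeze(1)
    e_idx = torch.arange(length).unsqueeze(0)
    valid = (e_idx >= s_idx) & (e_idx < s_idx + max_len)
    scores = scores.masked_fill(~valid, float("-inf"))
    best = scores.flatten().topk(min(top_k, int(valid.sum())))
    # flat index i encodes (start, end) = (i // length, i % length)
    return [(i // length, i % length) for i in best.indices.tolist()]


def precision_at_n(ranked, gold, max_n=5):
    """Fraction of questions whose first correct span appears at rank N, for N = 1..max_n."""
    counts = [0] * max_n
    for spans, answers in zip(ranked, gold):
        for rank, span in enumerate(spans[:max_n]):
            if span in answers:
                counts[rank] += 1
                break
    return [c / len(ranked) for c in counts]


p_s = torch.tensor([0.1, 0.7, 0.2])
p_e = torch.tensor([0.1, 0.3, 0.6])
```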
- Evaluate fusing GloVe vectors with the byte-level embeddings.
- Pre-training the convolutional encoder to reproduce GloVe embeddings (so that GloVe is not needed at test time).
- Curriculum learning with a set of losses of varying difficulty.
- Mixture of experts models.