Predictive Coding for Natural Language Representation

Vladimir Araujo and Marie-Francine Moens

Pre-trained neural language models are among the leading methods for learning useful representations of textual data. Several pre-training objectives have been proposed in recent years to generate representations at the word or sentence level, such as causal language modeling (Radford et al., 2018, 2019), masked language modeling (Devlin et al., 2019), and permutation language modeling (Yang et al., 2019). However, these approaches do not produce suitable representations at the discourse level (Huber et al., 2020). Meanwhile, neuroscience studies have suggested that predictive coding (PC) plays an essential role in human language development (Ylinen et al., 2016; Zettersten, 2019). PC postulates that the brain continually makes predictions of incoming sensory stimuli (Rao and Ballard, 1999; Friston, 2005; Clark, 2013; Hohwy, 2013), with word prediction being the primary mechanism (Berkum et al., 2005; Kuperberg and Jaeger, 2015). Recent studies, however, speculate that the predictive process could occur within and across utterances, fostering discourse comprehension (Kandylaki et al., 2016; Pickering and Gambi, 2018). This work extends BERT-type models with recursive bottom-up and top-down computation based on PC theory. Specifically, we incorporate top-down connections that, according to PC, convey predictions from upper to lower layers. These predictions are contrasted with bottom-up representations to generate an error signal that guides the model’s optimization. Using this approach, we attempt to build feature representations that capture discourse-level relationships by continually predicting future sentences in a latent space. We evaluate our method on DiscoEval (Chen et al., 2019) and SciDTB for discourse evaluation (Huber et al., 2020) to assess whether the embeddings produced by our model capture discourse properties of sentences without finetuning. Our approach improves performance on 6 out of 11 tasks, excelling in discourse relationship detection.
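
As a rough illustration of how a top-down prediction can be contrasted with bottom-up representations to yield an error signal, the following is a minimal Python sketch, not the authors' implementation. It assumes a contrastive (InfoNCE-style) formulation, which the abstract does not specify. The function names `dot` and `contrastive_error` and the toy vectors are hypothetical.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def contrastive_error(prediction, candidates, target_index):
    """Assumed InfoNCE-style error signal.

    prediction   : top-down prediction of a future sentence representation
    candidates   : bottom-up sentence representations (one true, rest distractors)
    target_index : position of the true future sentence in `candidates`

    Returns the negative log-probability that the prediction selects the
    true candidate under a softmax over dot-product scores.
    """
    scores = [dot(prediction, c) for c in candidates]
    # Numerically stable log-sum-exp over the scores.
    m = max(scores)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_norm - scores[target_index]


# Toy bottom-up representations (hypothetical 2-d latent space).
future = [1.0, 0.0]       # the sentence that actually follows
distractor = [0.0, 1.0]   # an unrelated sentence

# Two hypothetical top-down predictions from an upper layer.
good_prediction = [2.0, 0.0]
bad_prediction = [0.0, 2.0]

good_error = contrastive_error(good_prediction, [future, distractor], 0)
bad_error = contrastive_error(bad_prediction, [future, distractor], 0)
```

In this toy setting, the prediction aligned with the true future sentence yields the smaller error, so minimizing the error pushes upper-layer predictions toward the representations that actually follow.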