Decoding Sequence Models: RNNs, LSTMs, and the Rise of Transformers
ChatGPT and its Cultural Relevance, circa 2024
By 2024, ChatGPT had taken the world by storm, with people everywhere using the tool for all kinds of tasks. Social media was flooded with posts about how people were using it to improve their own workflows and get ahead, along with the expected fear-mongering that comes with any new tech. ChatGPT is a Generative Pretrained Transformer (GPT) wrapped in a nice chat interface. The foundational goals for such a model:
- One that can reliably perform non-trivial tasks
  - we’ve all seen AI say a cat is a bat and a dog a rat
- Contextualizing conversations and remembering context
  - we don’t want to recount an event every time we chat about it
- General purpose and trainable to do any describable task
  - an assistant who’s freely available and can “learn”

The groundwork for these goals was laid in a groundbreaking paper from 2017, Attention Is All You Need by Vaswani et al.
Here we’ll talk about what really changed in 2017 and the key ideas discussed in the paper. This may be a series, who knows.
Sequence Transduction Models - circa 2014
Sequence transduction models are models that take in one input sequence and output another sequence. Translation, summarisation and speech recognition all fall into this category. These are all important applications that were being studied long before LLMs arrived. When one moves from a single-input/single-output model to an S2S (sequence-to-sequence) model, what really changes?
- Your inputs can be very short or very long.
- Your inputs can have related bits anywhere in the sequence.
What does one expect from such a model?
- Work with inputs of arbitrary lengths and return outputs of arbitrary lengths
- Recognise relationships within a sequence and respond to the entire sequence
- Be “effective” while doing the above
So what was the state of the art for sequence transduction models around 2014, and right up until the Transformer arrived? RNNs (Recurrent Neural Networks) and their much-improved variant, LSTMs (Long Short-Term Memory networks), were all over the AI world. Andrej Karpathy has a wonderful blog post from 2015, The Unreasonable Effectiveness of Recurrent Neural Networks, on how groundbreaking RNNs were: https://karpathy.github.io/2015/05/21/rnn-effectiveness/. Today, it reads like a historical note.
The key idea behind an RNN is to keep a memory (a hidden state) of the previous tokens in the sequence. That is how RNNs recognise relationships and respond to the sequence as a whole. Tokens are the bits a sequence is broken down into. In English, we break sentences down into words separated by spaces. Tokens in machine learning are similar to words but not identical: they are chunks (often pieces of words) chosen to suit the model rather than human grammar. Tokens are important for RNNs and for all kinds of ML compute.
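To make the idea concrete, here is a toy word-level tokenizer that splits the example sentence into words and punctuation and maps each distinct token to an integer ID. It is only a sketch: the `tokenize` and `build_vocab` helpers are hypothetical, and real models use learned sub-word tokenizers (byte-pair encoding, for example) rather than a regular expression.

```python
import re

def tokenize(text: str) -> list[str]:
    # Hypothetical word-level tokenizer: runs of word characters, or single punctuation marks
    return re.findall(r"\w+|[^\w\s]", text.lower())

def build_vocab(tokens: list[str]) -> dict[str, int]:
    # Give each distinct token an integer ID, in order of first appearance
    vocab: dict[str, int] = {}
    for tok in tokens:
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab

sentence = "I have 2 cats and want one more."
tokens = tokenize(sentence)
vocab = build_vocab(tokens)
ids = [vocab[t] for t in tokens]
print(tokens)  # ['i', 'have', '2', 'cats', 'and', 'want', 'one', 'more', '.']
print(ids)     # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```

A model never sees the strings themselves; it sees the IDs, which are then looked up as vectors (embeddings).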
Now, when an RNN sees an input sequence (“I have 2 cats and want one more.”), it first breaks it down into tokens. We won’t pretend to know exactly what these look like, but the result is a sequence of n tokens, [w1, w2, w3, …, wn], where n depends on the input. Since the model is looking at a fresh sequence, the initial hidden state is h0 = 0 (or some other initialisation; this whole article is a gross oversimplification). The RNN then goes over each of the n tokens in turn and calculates a new hidden state that depends on the current token and the previous hidden state:
\[h_t = f(w_t,\ h_{t-1}), \quad t = 1, \dots, n\]
This is a classic recurrence relation: hn-1 in turn depends on hn-2, and so on. So to calculate hn, we must calculate all n states, from h1 to hn, one after another. The longer a sequence, the longer it takes the RNN to compute its hidden states.
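The recurrence can be sketched in a few lines of NumPy. This assumes the simplest (Elman-style) RNN cell, `h_t = tanh(W_x w_t + W_h h_{t-1} + b)`, with random untrained weights and random stand-in embeddings; the sizes and names are illustrative. The thing to notice is the `for` loop: step t cannot start until step t-1 has finished.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_hidden = 8, 16                                   # illustrative sizes (assumed)
W_x = rng.normal(scale=0.1, size=(d_hidden, d_in))       # input-to-hidden weights
W_h = rng.normal(scale=0.1, size=(d_hidden, d_hidden))   # hidden-to-hidden weights
b = np.zeros(d_hidden)

def f(w_t, h_prev):
    # One recurrence step: h_t = tanh(W_x w_t + W_h h_{t-1} + b)
    return np.tanh(W_x @ w_t + W_h @ h_prev + b)

def encode(tokens):
    # tokens: array of shape (n, d_in), one embedding vector per token
    h = np.zeros(d_hidden)        # h_0 = 0
    states = []
    for w_t in tokens:            # strictly in order: step t needs h_{t-1}
        h = f(w_t, h)
        states.append(h)
    return h, states              # h_n is the "context" handed to the decoder

sequence = rng.normal(size=(9, d_in))   # stand-in embeddings for 9 tokens
h_n, states = encode(sequence)
print(h_n.shape, len(states))           # (16,) 9
```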
The “encoder” of the RNN is responsible for calculating these hidden states. Its output is the last hidden state, hn, which depends on all the tokens and is meant to carry the context of the entire sequence.
But where does the output come from?
The “decoder” then uses this “context” generated by the encoder to produce an output sequence. The decoder is a second, autoregressive RNN. Autoregressive means every output token depends on all the output tokens generated before it. To start with, the decoder only has hn and the list of previous outputs, [], which is empty.
\[h_{d0} = h_n\]
\[y_{1},\ h_{d1} = f(h_{d0},\ [])\]
\[y_{2},\ h_{d2} = f(h_{d1},\ [y_{1}])\]
\[y_{3},\ h_{d3} = f(h_{d2},\ [y_{1},\ y_{2}])\]
and so on.
The final output of the model is [y1, y2, y3, …, yk]. Written that way it doesn’t look like a sensible output, but once converted back into words it could read like:
Me: I have 2 cats and want one more.
RNN: You’ll have three cats then.
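A greedy version of the decoding loop might look like the sketch below. It assumes a vanilla RNN decoder, hypothetical start/end token IDs (`BOS`, `EOS`) and untrained random weights, so the IDs it prints are meaningless. One simplification differs from the equations above: rather than passing the whole list of previous outputs, a typical RNN decoder feeds back only the latest token and relies on its hidden state to remember the rest.

```python
import numpy as np

rng = np.random.default_rng(1)

vocab_size, d_emb, d_hidden = 10, 8, 16   # illustrative sizes (assumed)
BOS, EOS = 0, 1                           # hypothetical start/end token IDs
E = rng.normal(scale=0.1, size=(vocab_size, d_emb))         # output-token embeddings
W_y = rng.normal(scale=0.1, size=(d_hidden, d_emb))         # token-to-hidden weights
W_h = rng.normal(scale=0.1, size=(d_hidden, d_hidden))      # hidden-to-hidden weights
W_o = rng.normal(scale=0.1, size=(vocab_size, d_hidden))    # hidden state -> token scores

def decode(h_n, max_len=20):
    h = h_n          # h_{d0} = h_n
    prev = BOS       # nothing generated yet
    outputs = []
    for _ in range(max_len):
        # The hidden state summarises earlier outputs; only the latest token is fed in
        h = np.tanh(W_y @ E[prev] + W_h @ h)
        y = int(np.argmax(W_o @ h))   # greedy: pick the highest-scoring token
        if y == EOS:
            break
        outputs.append(y)
        prev = y                      # feed this output back in at the next step
    return outputs

context = rng.normal(size=d_hidden)   # stand-in for the encoder's h_n
ys = decode(context)
print(ys)   # token IDs [y1, y2, ...]; meaningless here because the weights are untrained
```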
On paper, this looks like a perfect system. In practice, however, RNNs faced plenty of challenges.
First, with longer sequences, hn just can’t hold enough context for the decoder to generate valuable responses. It is also hard to capture how related two tokens are when they are far apart, because information about the earlier token has to survive many recurrence steps to reach the later one.
Second, hn takes longer to calculate with longer sequences.
Third, this calculation HAS to be done sequentially; there is no way to parallelise it. Yet in recent years we’ve been using GPUs, which are great at performing parallel tasks, to speed up all kinds of compute, ML and otherwise. Making RNNs faster and more GPU-friendly would mean parallelising them, and the recurrence rules that out.
An improvement over the basic RNN was the LSTM, an RNN with gates that let it decide what information to store and what to toss. It can handle longer sequences, but it still processes tokens one at a time and eventually runs into the same problems.
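The “decide what to store and what to toss” part is done with gates. Below is a minimal single-step LSTM cell in NumPy, assuming the common formulation with forget, input and output gates and random untrained weights; the names and sizes are illustrative. Note that the outer loop still walks the tokens one at a time.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(2)
d_in, d_hidden = 8, 16   # illustrative sizes (assumed)
# One weight matrix per gate, each acting on [h_{t-1}, x_t] concatenated
W = {g: rng.normal(scale=0.1, size=(d_hidden, d_hidden + d_in)) for g in "fioc"}
b = {g: np.zeros(d_hidden) for g in "fioc"}

def lstm_step(x_t, h_prev, c_prev):
    z = np.concatenate([h_prev, x_t])
    f = sigmoid(W["f"] @ z + b["f"])        # forget gate: how much old memory to keep
    i = sigmoid(W["i"] @ z + b["i"])        # input gate: how much new information to write
    o = sigmoid(W["o"] @ z + b["o"])        # output gate: how much memory to expose
    c_tilde = np.tanh(W["c"] @ z + b["c"])  # candidate new information
    c = f * c_prev + i * c_tilde            # toss some old memory, store some new
    h = o * np.tanh(c)
    return h, c

h, c = np.zeros(d_hidden), np.zeros(d_hidden)
for x_t in rng.normal(size=(9, d_in)):      # still sequential: one token at a time
    h, c = lstm_step(x_t, h, c)
print(h.shape, c.shape)                     # (16,) (16,)
```

Pushing the forget gate towards 1 and the input gate towards 0 makes the cell carry its memory forward almost unchanged, which is how an LSTM can hold on to information for more steps than a plain RNN.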
LSTMs were very successful (even commercially) as single-purpose models, where a model was trained on very specific data and then used to make predictions on the same kind of data. Meanwhile, the ideas behind the GPT we know so well today were already taking shape.
GPT - What’s that?
GPT (Generative Pretrained Transformer) is an offshoot of the Transformer. Transformers themselves were demonstrated by Vaswani et al. on language translation tasks. Like the RNN described above, the Transformer has an encoder and a decoder stack, and even their responsibilities are similar: the encoder understands the input and the decoder produces outputs autoregressively. What they got rid of is recurrence. Recurrence is what makes RNNs slow and expensive. Recurrence is what makes them sequential. But recurrence is also what makes relationships between two distant tokens possible in RNNs. How else could we still recognise relationships, but do it non-sequentially?
The paper proposes representing each token in the sequence as an embedding vector and calculating how much each embedding “attends” to every other embedding in the input. “Attending” is a score computed from learned projections, and it is not symmetric: how much token A attends to token B need not equal how much B attends to A. We’ll discuss it later with all its nitty-gritty. What is key here is that it is a one-shot calculation for the entire sequence.
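As a preview of those details, here is a minimal single-head version of the scaled dot-product attention described in the paper, softmax(QKᵀ/√d_k)V, in NumPy. It assumes random projection matrices `W_q`, `W_k`, `W_v`, random stand-in embeddings and no masking. Because queries and keys come from different projections, the weight matrix is generally not symmetric, and the whole (n, n) matrix comes out of a couple of matrix multiplications rather than a loop over tokens.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    # X: (n, d_model), one embedding per token
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, n): how much token i attends to token j
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights          # all output rows computed in one shot

rng = np.random.default_rng(3)
n, d_model, d_k = 9, 8, 4                # illustrative sizes (assumed)
X = rng.normal(size=(n, d_model))        # stand-in token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)          # (9, 4) (9, 9)
```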
Does this do the things we were using recurrence for?
This does preserve relationships between tokens.
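What does it do better than recurrence? The number of sequential steps no longer grows with the length of the sequence (the total amount of work still grows, since every token is compared with every other token). It is better at capturing long-range dependencies: the first and the nth tokens are compared in exactly the same way as the (n-1)th and nth tokens. And because these comparisons don’t depend on one another, we can parallelise them and really tap into GPU power.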
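One way to see why this parallelises is to compute the same attention output two ways: all rows at once as a matrix product, and row by row in reverse order. The sketch assumes random stand-in Q, K and V matrices rather than learned projections. Because no output row waits on another, the order doesn’t matter, and the score between the first and last token is the same single dot product as the score between neighbours.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d = 6, 4                                               # illustrative sizes (assumed)
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))     # stand-in projected embeddings

def attend_row(i):
    # Output for token i needs only the inputs, never another token's output
    s = K @ Q[i] / np.sqrt(d)
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ V

# All rows at once: a single (n, n) score matrix, the kind of work GPUs excel at
S = Q @ K.T / np.sqrt(d)
W = np.exp(S - S.max(axis=1, keepdims=True))
W /= W.sum(axis=1, keepdims=True)
batched = W @ V

# Row by row, in reverse order: same result, because no row depends on another
looped = np.stack([attend_row(i) for i in reversed(range(n))])[::-1]
print(np.allclose(batched, looped))   # True
```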
To summarise, we’ve been trying to create non-trivial sequence-to-sequence models for a while, and GPT represents the first to break out of research and industry circles into widespread use.