What is the transformer machine learning model?

This article is part of Demystifying AI, a series of posts that (try to) disambiguate the jargon and myths surrounding AI. (In partnership with Paperspace)

In recent years, the transformer model has become one of the most important advances in deep learning and deep neural networks. It is mainly used for advanced natural language processing applications. Google is using it to enhance its search engine results, and OpenAI has used transformers to create its famous GPT-2 and GPT-3 models.

Since its debut in 2017, the transformer architecture has evolved and branched out into many variants, expanding beyond language tasks into other areas. Transformers have been used for time series forecasting. They are the key innovation behind AlphaFold, DeepMind’s protein structure prediction model. Codex, OpenAI’s source code–generation model, is based on transformers. More recently, transformers have found their way into computer vision, where they are slowly replacing convolutional neural networks (CNNs) in many complicated tasks.

Researchers are still exploring ways to improve transformers and use them in new applications. Here is a brief explainer about what makes transformers exciting and how they work.

Processing sequences with neural networks

[Figure: Feed-forward neural network vs. recurrent neural network]

The classic feed-forward neural network is not designed to keep track of sequential data; it maps each input to an output independently. This works for tasks such as classifying images but fails on sequential data such as text. A machine learning model that processes text must not only process every word but also take into account the order of the words and how they relate to each other. The meaning of a word can change depending on the words that come before and after it in the sentence.

Before transformers, recurrent neural networks (RNNs) were the go-to solution for natural language processing. When given a sequence of words, an RNN processes the first word and feeds the result back into the network as it processes the next word. This enables it to keep track of the entire sentence instead of processing each word separately.
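
To make that sequential dependency concrete, here is a minimal sketch of a toy single-unit RNN in plain Python. The weights `w_x`, `w_h`, and `b` are arbitrary illustrative values, not taken from any real model, and real RNNs use vector-valued hidden states. The point is only that each hidden state depends on the previous one, so the loop cannot be split across time steps.

```python
import math

def rnn_forward(inputs, w_x, w_h, b):
    """Run a toy single-unit RNN over a sequence of scalar inputs.

    Each new hidden state depends on the previous one, so the steps
    must run one after another rather than in parallel.
    """
    h = 0.0  # initial hidden state
    states = []
    for x in inputs:
        # Mix the current input with the state carried over from the previous word.
        h = math.tanh(w_x * x + w_h * h + b)
        states.append(h)
    return states

# Illustrative weights; a real RNN learns these during training.
states = rnn_forward([1.0, 0.5, -0.3], w_x=0.8, w_h=0.5, b=0.0)
print([round(s, 3) for s in states])
```

Feeding the same inputs in a different order produces a different final state, which is how the network encodes word order.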

Recurrent neural nets had disadvantages that limited their usefulness. First, they were very slow. Because they had to process data sequentially, they could not take full advantage of parallel computing hardware such as graphics processing units (GPUs) during training and inference. Second, they could not handle long sequences of text. As the RNN got deeper into a text excerpt, the effects of the first words of the sentence gradually faded. This problem, known as “vanishing gradients,” was especially harmful when two related words were very far apart in the text. And third, they only captured the relations between a word and the words that came before it. In reality, the meaning of a word depends on the words that come both before and after it.
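
The fading effect can be illustrated with the same kind of toy RNN. Assuming a single hidden unit with a tanh activation, backpropagating from the last state to the first multiplies the gradient by `w_h * (1 - h_t**2)` at every step; when that factor is below 1, the product shrinks exponentially with the length of the sequence. This is a simplified sketch with illustrative weights, not a full backpropagation-through-time implementation.

```python
import math

def gradient_through_time(inputs, w_x, w_h):
    """Return d h_T / d h_1 for the toy RNN h_t = tanh(w_x * x_t + w_h * h_{t-1})."""
    h = 0.0
    grad = 1.0
    for t, x in enumerate(inputs):
        h = math.tanh(w_x * x + w_h * h)
        if t > 0:
            # Chain rule: each step multiplies the gradient by dh_t/dh_{t-1}.
            grad *= w_h * (1.0 - h * h)
    return grad

for length in (5, 10, 20):
    print(length, gradient_through_time([1.0] * length, w_x=1.0, w_h=0.5))
```

With `w_h = 0.5`, each factor is at most 0.5 in magnitude, so after 20 words the influence of the first word is bounded by 0.5 to the 19th power.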

Long short-term memory (LSTM) networks, an improved variant of RNNs, solved the vanishing gradients problem to some degree and could handle longer sequences of text. But LSTMs were even slower to train than plain RNNs and still couldn’t take full advantage of parallel computing, because they still relied on processing text sequences serially.

Transformers, introduced in the 2017 paper “Attention Is All You Need,” made two key contributions. First, they made it possible to process entire sequences in parallel, which allowed the speed and capacity of sequential deep learning models to scale to unprecedented levels. And second, they built the entire architecture around “attention mechanisms,” which track the relations between words across very long text sequences in both forward and reverse directions.

Types of sequence models

[Figure: Types of RNN architectures]

Before we discuss how the transformer model works, it is worth discussing the types of problems that sequential neural networks solve.

A “vector to sequence” model takes a single input, such as an image, and produces a sequence of data, such as a description.

A “sequence to vector” model takes a sequence as input, such as a product review or a social media post, and outputs a single value, such as a sentiment score.

A “sequence to sequence” model takes a sequence as input, such as an English sentence, and outputs another sequence, such as the French translation of the sentence.

Despite their differences, all these types of models have one thing in common. They learn representations. The job of a neural network is to transform one type of data into another. During training, the hidden layers of the neural network (the layers that stand between the input and output) tune their parameters in a way that best represents the features of the input data type and maps it to the output.

The original transformer was designed as a sequence-to-sequence (seq2seq) model for machine translation (of course, seq2seq models are not limited to translation tasks). It is composed of an encoder module that compresses an input string from the source language into a set of vectors (one per token) that represent the words and their relations to each other. The decoder module transforms this encoded representation into a string of text in the destination language.

Tokens and embeddings

[Figure: Transformer tokens]

The input text must be processed and transformed into a unified format before being fed to the transformer. First, the text goes through a “tokenizer,” which breaks it down into chunks of characters that can be processed separately. The tokenization algorithm can depend on the application. In most cases, every word and punctuation mark roughly counts as one token. Some suffixes and prefixes count as separate tokens (e.g., “ize,” “ly,” and “pre”). The tokenizer produces a list of numbers that represent the token IDs of the input text.
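
A toy tokenizer along these lines might look like the sketch below. It is only an illustration: the suffix list, the `##` continuation marker (borrowed from the WordPiece convention used by BERT), and the helper names are hypothetical, and production tokenizers such as byte-pair encoding learn their subword vocabularies from data rather than relying on hand-written rules.

```python
import re

# Hypothetical suffixes to split off; real tokenizers learn subwords from data.
SUFFIXES = ("ize", "ly")

def tokenize(text):
    """Split text into words and punctuation, then peel off known suffixes."""
    pieces = re.findall(r"\w+|[^\w\s]", text.lower())
    tokens = []
    for piece in pieces:
        for suffix in SUFFIXES:
            # Only split longer words, so short words like "fly" stay whole.
            if piece.endswith(suffix) and len(piece) > len(suffix) + 2:
                # "##" marks a piece that continues the previous token.
                tokens.extend([piece[: -len(suffix)], "##" + suffix])
                break
        else:
            tokens.append(piece)
    return tokens

def build_vocab(token_lists):
    """Assign an integer ID to every distinct token; ID 0 is reserved for unknowns."""
    vocab = {"[UNK]": 0}
    for tokens in token_lists:
        for tok in tokens:
            vocab.setdefault(tok, len(vocab))
    return vocab

def encode(tokens, vocab):
    """Map tokens to IDs, falling back to [UNK] for tokens not in the vocabulary."""
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]

tokens = tokenize("The cat quickly crossed the road.")
vocab = build_vocab([tokens])
print(tokens, encode(tokens, vocab))
```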

The tokens are then converted into “word embeddings.” A word embedding is a vector that tries to capture the meaning of a word as a point in a multi-dimensional space. For example, the words “cat” and “dog” can have similar values across some dimensions because they are both used in sentences about animals and house pets. However, “cat” is closer to “lion” than to “wolf” along some other dimension that separates felines from canids. Similarly, “Paris” and “London” might be close to each other because they are both cities. However, “London” is closer to “England,” and “Paris” to “France,” along a dimension that separates countries. Word embeddings usually have hundreds of dimensions.
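
Closeness between embeddings is commonly measured with cosine similarity. The four-dimensional vectors below are invented for illustration (real embeddings are learned, have hundreds of dimensions, and rarely have cleanly labeled dimensions); the sketch simply shows how “cat” can end up closer to “lion” than to “wolf.”

```python
import math

# Invented 4-dimensional embeddings; loosely, the dimensions stand for
# [animal, house pet, feline, canine]. Real embeddings are learned.
EMBEDDINGS = {
    "cat":  [0.9, 0.8, 0.9, 0.1],
    "dog":  [0.9, 0.9, 0.1, 0.9],
    "lion": [0.9, 0.1, 0.9, 0.1],
    "wolf": [0.9, 0.1, 0.1, 0.9],
}

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

for other in ("dog", "lion", "wolf"):
    print(other, round(cosine_similarity(EMBEDDINGS["cat"], EMBEDDINGS[other]), 3))
```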

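In the original transformer, as in most transformer models, the embeddings are a layer of the network whose values are learned together with the rest of the model during training. Separately trained, pre-trained embedding models are also available and are used for many language tasks.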

Attention layers

[Figure: Transformer architecture]

Once the sentence is transformed into a list of word embeddings, it is fed into the transformer’s encoder module. Unlike RNN and LSTM models, the transformer does not receive one input at a time. It can receive an entire sentence’s worth of embedding values and process them in parallel. This lets transformers make far better use of parallel hardware than their predecessors, and it also enables them to examine the context of the text in both forward and backward directions.

To preserve the order of the words in the sentence, the transformer applies “positional encoding,” which modifies each embedding vector (typically by adding a position-dependent vector to it) so that it also represents the word’s location in the text.
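
One concrete scheme is the fixed sinusoidal encoding described in “Attention Is All You Need,” where even dimensions use a sine and odd dimensions a cosine of the position, each pair at a different frequency. The sketch below implements that formula in plain Python and adds the result to the embeddings. Many later models (BERT and GPT-2, for example) learn their position embeddings instead, so treat this as one option rather than the only one; the helper names are illustrative.

```python
import math

def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encodings from "Attention Is All You Need".

    PE[pos][2i]   = sin(pos / 10000 ** (2i / d_model))
    PE[pos][2i+1] = cos(pos / 10000 ** (2i / d_model))
    """
    pe = []
    for pos in range(seq_len):
        row = []
        for j in range(d_model):
            i = j // 2  # dimensions 2i and 2i+1 share the same frequency
            angle = pos / (10000 ** (2 * i / d_model))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        pe.append(row)
    return pe

def add_positions(embeddings):
    """Add the encoding for each position to the corresponding embedding vector."""
    pe = positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(emb, enc)] for emb, enc in zip(embeddings, pe)]

print(positional_encoding(3, 4))
```

Because the encoding is a fixed function of the position, identical words at different positions end up with different input vectors.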

Next, the input is passed to the first encoder block, which processes it through an “attention layer.” The attention layer tries to capture the relations between the words in the sentence. For example, consider the sentence “The big black cat crossed the road after it dropped a bottle on its side.” Here, the model must associate “it” with “cat” and “its” with “bottle.” It should also establish other associations, such as “big” and “cat” or “crossed” and “cat.” In other words, the attention layer receives a list of word embeddings that represent the values of individual words and produces a list of vectors that represent both individual words and their relations to each other. The attention layer contains multiple “attention heads,” each of which can capture different kinds of relations between words.
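
Inside each attention head, the original transformer uses scaled dot-product attention: every word’s query vector is compared with every word’s key vector, the scores are normalized with a softmax, and the resulting weights are used to average the value vectors. The plain-Python sketch below assumes the projection matrices `w_q`, `w_k`, and `w_v` are given; in a real model they are learned, and a multi-head layer runs several such heads with different projections and concatenates their outputs.

```python
import math

def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m); matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    peak = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(q, k, v):
    """Compute softmax(Q K^T / sqrt(d_k)) V for one attention head."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))  # how strongly each word relates to every other word
    scores = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = [softmax(row) for row in scores]  # each row sums to 1
    return matmul(weights, v), weights

def self_attention(x, w_q, w_k, w_v):
    """Project the embeddings into queries, keys, and values, then attend."""
    return scaled_dot_product_attention(matmul(x, w_q), matmul(x, w_k), matmul(x, w_v))
```

Every row of the weight matrix can be computed independently, which is what lets the transformer process all the words of a sentence in parallel.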

The output of the attention layer is fed to a feed-forward neural network that transforms it into a vector representation and sends it to the next attention layer. Transformers contain several blocks of attention and feed-forward layers to gradually capture more complicated relationships.
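
In the original paper, this feed-forward network is applied to each position separately and identically, and each sub-layer is wrapped in a residual connection followed by layer normalization. The sketch below shows that structure with small hand-made weights; the learned scale and shift of layer normalization are omitted for brevity, and the helper names are illustrative.

```python
import math

def linear(x, w, b):
    """Compute x @ w + b for a single vector x (w is a list of rows)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

def feed_forward(x, w1, b1, w2, b2):
    """Position-wise FFN from the original paper: max(0, x W1 + b1) W2 + b2."""
    hidden = [max(0.0, h) for h in linear(x, w1, b1)]  # ReLU
    return linear(hidden, w2, b2)

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (learned scale/shift omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def ffn_sublayer(positions, w1, b1, w2, b2):
    """Apply the same FFN to every position, then add the residual and normalize."""
    return [layer_norm([a + f for a, f in zip(x, feed_forward(x, w1, b1, w2, b2))])
            for x in positions]
```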

The task of the decoder module is to translate the encoder’s attention vectors into the output data (e.g., the translated version of the input text). During the training phase, the decoder has access to both the attention vectors produced by the encoder and the expected outcome (e.g., the translated string).

The decoder uses the same tokenization, word embedding, and attention mechanism to process the expected outcome and create attention vectors. It then passes these attention vectors, together with the encoder’s output, to an additional attention layer (often called encoder-decoder attention or cross-attention), which establishes relations between the input and output values. In the translation application, this is the part where the words from the source and destination languages are mapped to each other. As in the encoder module, the decoder’s attention output is passed through a feed-forward layer. The result is then mapped to a very large vector the size of the target vocabulary (in the case of language translation, this can span tens of thousands of words).
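
The two final decoder steps, attending over the encoder’s output and projecting onto the vocabulary, can be sketched as follows. This is a simplified illustration: the learned query, key, and value projections are omitted (treated as identity), a single head is used, and the seven-word French vocabulary and output weights `w_out` are hypothetical stand-ins for a vocabulary of tens of thousands of tokens. Picking the highest-probability word, as `next_token` does, is greedy decoding; real systems often use beam search or sampling instead.

```python
import math

# Hypothetical tiny target vocabulary; real models use tens of thousands of tokens.
VOCAB = ["<eos>", "le", "chat", "noir", "traverse", "la", "route"]

def softmax(row):
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(decoder_states, encoder_states):
    """Queries come from the decoder; keys and values come from the encoder output.

    Learned W_Q, W_K, W_V projections are omitted (treated as identity) for brevity.
    """
    d_k = len(encoder_states[0])
    outputs = []
    for q in decoder_states:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in encoder_states]
        weights = softmax(scores)
        # Weighted average of the encoder vectors (one per source-language token).
        outputs.append([sum(w * v[j] for w, v in zip(weights, encoder_states))
                        for j in range(d_k)])
    return outputs

def next_token(h, w_out, b_out):
    """Project a decoder vector onto the vocabulary and pick the most likely word."""
    logits = [sum(h_i * w_out[i][j] for i, h_i in enumerate(h)) + b_out[j]
              for j in range(len(VOCAB))]
    probs = softmax(logits)
    best = max(range(len(probs)), key=lambda j: probs[j])
    return VOCAB[best], probs
```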

Training the transformer

During training, the transformer is provided with a very large corpus of paired examples (e.g., English sentences and their corresponding French translations). The encoder module receives and processes the full input string. The decoder, however, receives a masked version of the output string, in which each position can see only the words that come before it, and tries to establish the mappings between the encoded attention vectors and the expected outcome. The decoder tries to predict the next word, and the model is corrected based on the difference between its output and the expected outcome. This feedback enables the transformer to modify the parameters of the encoder and decoder and gradually create the right mappings between the input and output languages.
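
Three pieces of this training setup can be sketched in a few lines: the decoder input is the expected output shifted right by one position (often called teacher forcing), a causal mask stops each position from attending to later words, and a cross-entropy loss measures how far the predicted next-word probabilities are from the expected words. The start-token ID `bos_id` and the helper names are hypothetical, and a real implementation would compute these on batched tensors inside an automatic-differentiation framework.

```python
import math

def shift_right(target_ids, bos_id):
    """Teacher forcing: feed the decoder the expected output shifted right by one."""
    return [bos_id] + target_ids[:-1]

def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (only j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]

def masked_softmax(scores, mask):
    """Row-wise softmax that gives zero weight to masked (future) positions."""
    weights = []
    for row, allowed in zip(scores, mask):
        peak = max(s for s, ok in zip(row, allowed) if ok)
        exps = [math.exp(s - peak) if ok else 0.0 for s, ok in zip(row, allowed)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

def cross_entropy(probs_per_step, target_ids):
    """Average negative log-probability assigned to the expected next token."""
    losses = [-math.log(probs[t]) for probs, t in zip(probs_per_step, target_ids)]
    return sum(losses) / len(losses)
```

Because the mask hides future words, every position of the target sentence can be trained at once while still learning to predict only from earlier words.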

The more training data and parameters the transformer has, the more capacity it gains to maintain coherence and consistency across long sequences of text.

Variations of the transformer

In the machine translation example that we examined above, the encoder module of the transformer learns the relations between English words and sentences, and the decoder learns the mappings between English and French.

But not all transformer applications require both the encoder and decoder module. For example, the GPT family of large language models uses stacks of decoder modules to generate text. BERT, another variation of the transformer model developed by researchers at Google, only uses encoder modules.

The advantage of some of these architectures is that they can be trained through self-supervised learning or unsupervised methods. BERT, for example, does much of its training by taking large corpora of unlabeled text, masking parts of it, and trying to predict the missing parts. It then tunes its parameters based on how close its predictions were to the actual data. By continuously going through this process, BERT captures the statistical relations between different words in different contexts. After this pretraining phase, BERT can be fine-tuned for a downstream task such as question answering, text summarization, or sentiment analysis by training it on a small number of labeled examples.
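
The masking step can be sketched as below. The proportions follow those reported in the BERT paper (about 15% of tokens are selected; of those, 80% are replaced with a `[MASK]` token, 10% with a random token, and 10% left unchanged), but the function itself is a simplified, hypothetical illustration that works on word strings rather than BERT’s actual subword IDs.

```python
import random

def mask_tokens(tokens, vocab, mask_prob=0.15, seed=0):
    """BERT-style masking sketch on a list of word strings.

    Returns the corrupted tokens and a list of labels, where each label is the
    original token the model must recover, or None for positions it is not
    asked to predict.
    """
    rng = random.Random(seed)  # seeded so the example is reproducible
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # not selected as a prediction target
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            corrupted[i] = "[MASK]"
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # replace with a random token
        # Otherwise keep the original token unchanged.
    return corrupted, labels

words = "the big black cat crossed the road".split()
print(mask_tokens(words, vocab=words, seed=3))
```

Training then compares the model’s predictions with the non-`None` labels only, so no human annotation is needed.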

Using unsupervised and self-supervised pretraining reduces the manual effort required to annotate training data.

A lot more can be said about transformers and the new applications they are unlocking, but that is beyond the scope of this article. Researchers are still finding ways to squeeze more out of transformers.

Transformers have also sparked discussions about language understanding and artificial general intelligence. What is clear is that transformers, like other neural networks, are statistical models that capture regularities in data in clever and complicated ways. They do not “understand” language in the way that humans do. But they are exciting and useful nonetheless and have a lot to offer.
