Training a Transformer
I saw an article about "Attention" on X and quickly realized something: I had never actually implemented it myself.
Then I thought, "Why not build a transformer from scratch and train it?" Just for fun, just to learn.
In the original paper, the authors trained their transformer on a translation task. That felt like the perfect place to start.
Chapter 1: The Basics
Before we begin implementing the paper, we need to have a solid understanding of the concept of Attention. Let's cover some basics first.
Keep in mind, it helps to have a basic understanding of Recurrent Neural Networks and a bit of Machine Learning to get the most out of this article/blog.
Chapter 1.1: The Basics - Why Transformers?
Prior to transformers, Recurrent Neural Networks (RNNs) were the dominant approach for sequence modelling tasks. For tasks such as translation, RNNs are typically arranged in an encoder-decoder architecture. The encoder processes the input sequence and produces hidden states, and the decoder uses those hidden states to generate the output sequence.
Let's take the example sentence "Sai loves to watch football" and its French translation, "Sai aime regarder le football", to quickly demonstrate how an RNN encoder-decoder actually works.
Encoding:
Sai → Word Vector → Encoder(h₀, Word Vector) → h₁
loves → Word Vector → Encoder(h₁, Word Vector) → h₂
to → Word Vector → Encoder(h₂, Word Vector) → h₃
watch → Word Vector → Encoder(h₃, Word Vector) → h₄
football → Word Vector → Encoder(h₄, Word Vector) → h₅
The final hidden state h₅ is used as the context vector, a summary of the entire input sentence. This single vector is all that gets passed to the decoder, and it tends to be biased toward the words processed last rather than those at the beginning.
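The encoding steps above can be made concrete with a minimal NumPy sketch of a vanilla RNN encoder. It assumes the common update rule h_t = tanh(W_x·x_t + W_h·h_{t-1} + b). The weights are random and untrained, the word vectors are random stand-ins for real embeddings, and all names and sizes (`W_x`, `W_h`, `encode`, 8 and 16 dimensions) are illustrative rather than taken from any library.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden = 8, 16  # toy word-vector and hidden-state sizes (assumed)

# Randomly initialised (untrained) parameters of a vanilla RNN cell
W_x = rng.normal(scale=0.1, size=(d_hidden, d_in))
W_h = rng.normal(scale=0.1, size=(d_hidden, d_hidden))
b = np.zeros(d_hidden)


def encode(word_vectors):
    """Fold a sequence of word vectors into hidden states h_1..h_T."""
    h = np.zeros(d_hidden)  # h_0
    states = []
    for x in word_vectors:  # strictly one word after another
        h = np.tanh(W_x @ x + W_h @ h + b)
        states.append(h)
    return states


sentence = ["Sai", "loves", "to", "watch", "football"]
# Hypothetical word vectors: one random vector per word
word_vectors = [rng.normal(size=d_in) for _ in sentence]
states = encode(word_vectors)
context = states[-1]  # h_5: the only thing the decoder receives
print(len(states), context.shape)  # 5 (16,)
```

The loop cannot compute h₃ before h₂ exists. Only `states[-1]` reaches the decoder, so everything about "Sai" has to survive four more tanh updates.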
Decoding:
<START> → Word Vector → Decoder(h₅, Word Vector) → h₆ → Output: "Sai"
"Sai" → Word Vector → Decoder(h₆, Word Vector) → h₇ → Output: "aime"
"aime" → Word Vector → Decoder(h₇, Word Vector) → h₈ → Output: "regarder"
"regarder" → Word Vector → Decoder(h₈, Word Vector) → h₉ → Output: "le"
"le" → Word Vector → Decoder(h₉, Word Vector) → h₁₀ → Output: "football"
"football" → Word Vector → Decoder(h₁₀, Word Vector) → h₁₁ → Output: <END>
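Decoding follows the same pattern, except that each predicted word is fed back in as the next input until `<END>` appears. The sketch below is a toy greedy decoder built on an untrained vanilla RNN. The tiny vocabulary, the weights (`E`, `W_x`, `W_h`, `W_out`), and the `max_len` cutoff are all hypothetical, so its output is meaningless until the weights are trained.

```python
import numpy as np

rng = np.random.default_rng(1)
vocab = ["<START>", "<END>", "Sai", "aime", "regarder", "le", "football"]
d_emb, d_hidden = 8, 16  # toy sizes (assumed)

# Hypothetical, untrained parameters
E = rng.normal(size=(len(vocab), d_emb))         # target-side word vectors
W_x = rng.normal(scale=0.1, size=(d_hidden, d_emb))
W_h = rng.normal(scale=0.1, size=(d_hidden, d_hidden))
W_out = rng.normal(size=(len(vocab), d_hidden))  # hidden state -> word scores


def greedy_decode(context, max_len=10):
    """Generate words one at a time, feeding each prediction back in."""
    h = context  # the encoder's final hidden state seeds the decoder
    token = vocab.index("<START>")
    output = []
    for _ in range(max_len):  # cap the length in case <END> never appears
        h = np.tanh(W_x @ E[token] + W_h @ h)
        token = int(np.argmax(W_out @ h))  # greedy: take the highest-scoring word
        if vocab[token] == "<END>":
            break
        output.append(vocab[token])
    return output


# A zero vector stands in for a real encoder's final hidden state
print(greedy_decode(np.zeros(d_hidden)))
```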
As the sentence gets longer, the first word's contribution to the encoder's final hidden state shrinks significantly. The encoder effectively "forgets" that word, which hurts translation accuracy.
This happens largely because of the vanishing (or exploding) gradient problem. During training, the gradient flowing back to an early word is multiplied by one factor per time step, so it shrinks (or blows up) exponentially with distance. This makes it difficult for an RNN to learn to retain and propagate information across long sentences.
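A small numerical experiment shows why early words fade. By the chain rule, ∂h_T/∂h_0 is a product of per-step Jacobians diag(1 − h_t²)·W_h, assuming a tanh update without bias. The sketch below uses toy sizes and random inputs and tracks the spectral norm of that product. It deliberately rescales W_h so that its largest singular value is 0.9. Because 0 < 1 − h_t² ≤ 1, the norm after t steps can then be at most 0.9^t. With a largest singular value above 1, the product can instead grow, which is the exploding-gradient case.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, T = 8, 16, 30  # toy sizes and sequence length (assumed)

W_x = rng.normal(scale=0.1, size=(d_hidden, d_in))
W_h = rng.normal(size=(d_hidden, d_hidden))
W_h *= 0.9 / np.linalg.norm(W_h, 2)  # rescale: largest singular value = 0.9

h = np.zeros(d_hidden)
J = np.eye(d_hidden)  # running product: dh_t / dh_0
norms = []
for t in range(T):
    x = rng.normal(size=d_in)  # random stand-in for a word vector
    h = np.tanh(W_x @ x + W_h @ h)
    step = np.diag(1.0 - h**2) @ W_h  # Jacobian of one step, dh_t / dh_{t-1}
    J = step @ J  # chain rule: the newest step multiplies on the left
    norms.append(np.linalg.norm(J, 2))  # spectral norm of dh_t / dh_0

print(f"after 1 step: {norms[0]:.3f}, after {T} steps: {norms[-1]:.2e}")
```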
In addition, RNNs are slow to train. Because they process inputs sequentially, one step at a time, they are difficult to parallelize, which makes them inefficient for large-scale training.
These were the major reasons people started looking for alternatives to vanilla RNNs. Architectures like Bidirectional RNNs, LSTMs, and GRUs were proposed to address some of these issues, but none of them proved as successful as transformers.
Chapter 1.2: The Basics - Attention
The key innovation behind transformers is that they rely entirely on a mechanism called "Attention" and drop recurrence altogether. Like many others, I was first introduced to this concept through the landmark paper “Attention Is All You Need” by Vaswani et al. (2017). Considered by many to be the holy grail of modern AI, that paper completely changed the landscape of the field.
Chapter 2: Understanding what happens inside the transformer
In this chapter, we will take a deep dive into the inner workings of a transformer model, breaking it down step by step. Throughout this process, we will use concrete examples to illustrate each step and make it easier to understand.
Chapter 2.1: Tokenization
Before the transformer can process text, it needs to split the text into pieces called tokens and convert each token into a number. Think of a tokenizer as a translator that converts text into a numerical language that neural networks can process.
The process of tokenization involves two key steps:
Splitting Text
- Some tokenizers split on spaces, treating each word as a separate token.
- Others break words into smaller subword units, which helps with handling rare or unseen words.
Converting into IDs
- Once the text is split, each unique word or subword is assigned a unique ID in a dictionary, known as the vocabulary.
- This vocabulary is used to map each word or subword to its corresponding integer ID for further processing in the model.
Let's walk through an example to understand the process of tokenization.
Given the sentence: "Sai likes to watch football"
Step 1: Splitting Text
- The text is split into individual words (tokens):
["sai", "likes", "to", "watch", "football"]
Alternatively, using subword tokenization (if the tokenizer breaks words into smaller units), we might get:
["s", "ai", "likes", "to", "watch", "football"]
Step 2: Converting into IDs
- After splitting the whole corpus of data into words/subwords, we form a vocabulary that might look like this:
| Token | ID |
|---|---|
| s | 1 |
| ai | 2 |
| watch | 3 |
| to | 4 |
| likes | 5 |
| football | 6 |
Using the subword split from Step 1, the sentence "Sai likes to watch football" maps to the token IDs:
[1, 2, 5, 4, 3, 6]
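Both steps can be reproduced in a few lines of Python using the toy vocabulary from the table. The subword split here is a simple greedy longest-match against the vocabulary, chosen only for illustration. Real tokenizers such as BPE learn their splits differently. `UNK_ID = 0` is a hypothetical ID for pieces missing from the table.

```python
VOCAB = {"s": 1, "ai": 2, "watch": 3, "to": 4, "likes": 5, "football": 6}
UNK_ID = 0  # hypothetical ID for unknown pieces (not in the table)


def split_subwords(word, vocab):
    """Greedy longest-match split of one word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        for end in range(len(word), start, -1):  # try the longest piece first
            if word[start:end] in vocab:
                pieces.append(word[start:end])
                start = end
                break
        else:  # no piece matched: emit a single character
            pieces.append(word[start])
            start += 1
    return pieces


def tokenize(text, vocab):
    words = text.lower().split()  # Step 1a: lowercase, split on whitespace
    return [p for w in words for p in split_subwords(w, vocab)]  # Step 1b


def to_ids(tokens, vocab):
    return [vocab.get(t, UNK_ID) for t in tokens]  # Step 2: tokens -> IDs


tokens = tokenize("Sai likes to watch football", VOCAB)
print(tokens)                 # ['s', 'ai', 'likes', 'to', 'watch', 'football']
print(to_ids(tokens, VOCAB))  # [1, 2, 5, 4, 3, 6]
```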
Chapter 2.2: Word Embeddings
These token IDs, however, carry no inherent meaning by themselves. A sequence of numbers like [1, 2, 5, 4, 3, 6] tells us nothing about how the words in the sentence relate to one another. To capture the semantic meaning of words and their relationships, we need richer representations.
This is where embeddings come in. Word embeddings map each token to a learned, high-dimensional vector.
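Mechanically, an embedding layer is just a lookup table with one row per token ID. Here is a minimal NumPy sketch. It assumes toy sizes (`vocab_size = 7` so that IDs 0 to 6 fit, `d_model = 4`) and uses random values in place of learned ones.

```python
import numpy as np

vocab_size, d_model = 7, 4  # toy sizes: IDs 0..6, 4-dimensional vectors
rng = np.random.default_rng(0)

# One row per token ID. Random here; learned during training in a real model.
embedding_table = rng.normal(size=(vocab_size, d_model))

token_ids = np.array([1, 2, 5, 4, 3, 6])  # "Sai likes to watch football"
vectors = embedding_table[token_ids]      # row lookup -> shape (6, d_model)
print(vectors.shape)  # (6, 4)
```

In PyTorch, `torch.nn.Embedding(vocab_size, d_model)` performs the same row lookup with a trainable weight matrix, so the vectors are adjusted during training. The base model in the original paper uses d_model = 512.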
Chapter 3: Coding the transformer
Now that we have a solid understanding of transformers, let's code one from scratch. We'll follow the same flow as before, starting with tokenizers. First, let's take a closer look at them.
Chapter 3.1: The Tokenizer
As seen in the previous chapter, a tokenizer converts text into input IDs, which are numerical representations of words or subwords. This step is essential before feeding data into a transformer model. There are two options:
- Build a custom tokenizer tailored to your dataset using Hugging Face's tokenizers library.
- Use a pretrained tokenizer.
For the custom tokenizer, we will use the Byte Pair Encoding (BPE) algorithm. Here's how you can create one with Hugging Face's tokenizers library:
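```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.normalizers import NFD, Lowercase, Sequence, StripAccents
from tokenizers.pre_tokenizers import Punctuation, WhitespaceSplit
from tokenizers.pre_tokenizers import Sequence as PreSequence


def get_all_sentences(ds, lang):
    # Assumed helper: yields the text column `lang` from each row of the dataset
    for item in ds:
        yield item[lang]


def build_custom_tokenizer(ds, lang, tokenizer_path):
    # Initialize a BPE tokenizer with an unknown token
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    if lang == "en_text":
        tokenizer.normalizer = Sequence([
            NFD(),           # Decompose characters so accents become separate marks
            StripAccents(),  # Remove those marks (e.g., café → cafe)
            Lowercase()      # Convert to lowercase
        ])
    else:
        tokenizer.normalizer = Sequence([Lowercase()])
    # Split on whitespace, then separate punctuation into its own tokens
    tokenizer.pre_tokenizer = PreSequence([WhitespaceSplit(), Punctuation()])

    # Configure BPE training parameters
    trainer = BpeTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=3,  # Only merge pairs that occur at least 3 times
        vocab_size=60000  # Upper bound on the final vocabulary size
    )

    # Train the tokenizer on the provided dataset
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

    # Save the trained tokenizer to disk
    tokenizer.save(str(tokenizer_path))
    return tokenizer
```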
You can tune the min_frequency and vocab_size parameters depending on how large you want your vocabulary to be. A good rule of thumb is to increase min_frequency until the vocabulary fits comfortably under the vocab_size limit. A smaller vocabulary keeps the embedding and output layers, whose size grows with the number of tokens, from making the model too large and resource-intensive to train.
Training a custom tokenizer can be highly beneficial when developing models on specialized datasets, such as medical data or other domain-specific content. A custom tokenizer allows you to account for terms and phrases that are uncommon in general datasets but frequently appear in your specific data, improving the model's understanding and performance within that domain.
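To see roughly what `BpeTrainer` does during training, here is a toy, pure-Python version of the BPE merge loop. It starts from single characters and repeatedly merges the most frequent adjacent pair. It stops when no pair reaches `min_frequency` or the merge budget runs out. This is a simplified sketch, not the library's implementation: it ignores special tokens, pre-tokenization, and the library's exact tie-breaking. The function name `learn_bpe` is hypothetical.

```python
from collections import Counter


def learn_bpe(words, num_merges, min_frequency=1):
    """Toy BPE: learn up to num_merges merge rules from a list of words."""
    # Each word starts as a tuple of single characters, weighted by frequency
    vocab = Counter(tuple(w) for w in words)
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair across the corpus
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break  # every word is already a single symbol
        best = max(pairs, key=pairs.get)  # ties go to the first pair seen
        if pairs[best] < min_frequency:
            break  # the most frequent pair is too rare to merge
        merges.append(best)
        # Rewrite every word, replacing occurrences of `best` with one symbol
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges, vocab


words = ["low", "low", "lower", "lowest"]
merges, vocab = learn_bpe(words, num_merges=10, min_frequency=3)
print(merges)  # [('l', 'o'), ('lo', 'w')]
```

With `min_frequency=3`, this toy corpus stops after two merges because the next candidate pair, ("low", "e"), occurs only twice. This is the same effect as the `min_frequency` tuning described above: raising it shrinks the vocabulary.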
Pretrained Tokenizer
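```python
from transformers import AutoTokenizer


def use_custom_tokenizer(tokenizer_name, cache_dir="tokenizers"):
    # Loads a pretrained tokenizer from the Hugging Face Hub (cached in cache_dir)
    print(f"Loading pre-trained tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)

    class TokenizerWrapper:
        """Exposes a Hugging Face tokenizer through the same interface as our custom one."""

        def __init__(self, hf_tokenizer):
            self.tokenizer = hf_tokenizer
            self.pad_token_id = hf_tokenizer.pad_token_id
            # BERT-style tokenizers have no SOS/EOS, so reuse [CLS] and [SEP]
            self.eos_token_id = hf_tokenizer.sep_token_id
            self.sos_token_id = hf_tokenizer.cls_token_id
            self.unk_token_id = hf_tokenizer.unk_token_id

        def encode(self, text):
            # Raw token IDs without [CLS]/[SEP]; SOS/EOS are added separately
            result = self.tokenizer(
                text, add_special_tokens=False, return_attention_mask=False
            )
            return result["input_ids"]

        def decode(self, ids, skip_special_tokens=False):
            return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

        def get_vocab_size(self):
            # len() also counts any tokens added on top of the base vocabulary
            return len(self.tokenizer)

    return TokenizerWrapper(tokenizer)
```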
When using pretrained tokenizers from Hugging Face, be mindful of special tokens, as they vary across architectures. For example, a pretrained tokenizer from a BERT-based model has no SOS or EOS tokens. Instead, it uses [CLS] and [SEP] tokens, which play similar roles.
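Whichever tokenizer you use, it is safer to read these IDs from the tokenizer than to hard-code them. The hypothetical helper below shows one common way they are used: it wraps the raw IDs from `encode` with the start and end markers, then right-pads them to a fixed sequence length. The ID values in the example are made up for illustration. With the wrapper above you would pass `sos_token_id`, `eos_token_id`, and `pad_token_id`.

```python
def build_model_input(ids, sos_id, eos_id, pad_id, seq_len):
    """Hypothetical helper: add start/end markers, then right-pad to seq_len."""
    ids = [sos_id] + list(ids) + [eos_id]
    if len(ids) > seq_len:
        raise ValueError(f"sequence of {len(ids)} tokens exceeds seq_len={seq_len}")
    return ids + [pad_id] * (seq_len - len(ids))


# Example IDs (made up): SOS=2, EOS=3, PAD=1
print(build_model_input([10, 11], sos_id=2, eos_id=3, pad_id=1, seq_len=6))
# [2, 10, 11, 3, 1, 1]
```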
For general-purpose text, a pretrained tokenizer (for example, one loaded through tiktoken, which OpenAI's models use, or a SentencePiece model) is often a better choice than training one from scratch, since it was trained on a far larger corpus.
These input ids