• This is a standalone notebook implementing the popular byte pair encoding (BPE) tokenization algorithm, which is used in models such as GPT-2, GPT-4, and Llama 3, from scratch for educational purposes
  • For more details about the purpose of tokenization, please refer to Chapter 2; this code here is bonus material explaining the BPE algorithm
  • The original BPE tokenizer that OpenAI implemented for training the original GPT models can be found here
  • The BPE algorithm was originally described in 1994: “A New Algorithm for Data Compression” by Philip Gage
  • Most projects, including Llama 3, nowadays use OpenAI’s open-source tiktoken library due to its computational performance; it allows loading pretrained GPT-2 and GPT-4 tokenizers, for example (the Llama 3 models were trained using the GPT-4 tokenizer as well)
  • The difference between the implementations above and my implementation in this notebook is that the code here additionally includes a function for training the tokenizer (for educational purposes)
  • There’s also an implementation called minBPE with training support, which may be more performant (my implementation here focuses on educational purposes); in contrast to minBPE, my implementation additionally allows loading the original OpenAI tokenizer vocabulary and merges

 

1. The main idea behind byte pair encoding (BPE)

  • The main idea in BPE is to convert text into an integer representation (token IDs) for LLM training (see Chapter 2)

 

1.1 Bits and bytes

  • Before getting to the BPE algorithm, let’s introduce the notion of bytes
  • Consider converting text into a byte array (BPE stands for “byte” pair encoding after all):
text = "This is some text"
byte_ary = bytearray(text, "utf-8")
print(byte_ary)
bytearray(b'This is some text')
  • When we call list() on a bytearray object, each byte is treated as an individual element, and the result is a list of integers corresponding to the byte values:
ids = list(byte_ary)
print(ids)
[84, 104, 105, 115, 32, 105, 115, 32, 115, 111, 109, 101, 32, 116, 101, 120, 116]
  • This would be a valid way to convert text into a token ID representation that we need for the embedding layer of an LLM
  • However, the downside of this approach is that it is creating one ID for each character (that’s a lot of IDs for a short text!)
  • I.e., this means for a 17-character input text, we have to use 17 token IDs as input to the LLM:
print("Number of characters:", len(text))
print("Number of token IDs:", len(ids))
Number of characters: 17
Number of token IDs: 17
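  • As a quick aside, these byte values can be converted straight back into the original string, which shows that this byte-level representation is lossless:
print(bytes(ids).decode("utf-8"))
# prints: This is some text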
  • If you have worked with LLMs before, you may know that the BPE tokenizers have a vocabulary where we have a token ID for whole words or subwords instead of each character
  • For example, the GPT-2 tokenizer tokenizes the same text (“This is some text”) into only 4 instead of 17 tokens: 1212, 318, 617, 2420
  • You can double-check this using the interactive tiktoken app or the tiktoken library:

import tiktoken

gpt2_tokenizer = tiktoken.get_encoding("gpt2")
gpt2_tokenizer.encode("This is some text")
# prints [1212, 318, 617, 2420]
  • Since a byte consists of 8 bits, there are 2^8 = 256 possible values that a single byte can represent, ranging from 0 to 255
  • You can confirm this by executing the code bytearray(range(0, 257)), which raises ValueError: byte must be in range(0, 256) (see the small check at the end of this subsection)
  • A BPE tokenizer usually uses these 256 values as its first 256 single-character tokens; one could visually check this by running the following code:
import tiktoken
gpt2_tokenizer = tiktoken.get_encoding("gpt2")

for i in range(300):
    decoded = gpt2_tokenizer.decode([i])
    print(f"{i}: {decoded}")
"""
prints:
0: !
1: "
2: #
...
255: �  # <---- single character tokens up to here
256:  t
257:  a
...
298: ent
299:  n
"""
  • Above, note that entries 256 and 257 are not single-character values but double-character values (a whitespace + a letter), which is a little shortcoming of the original GPT-2 BPE Tokenizer (this has been improved in the GPT-4 tokenizer)
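  • And as a small check of the 256-value limit mentioned earlier, the following snippet triggers the ValueError from bytearray(range(0, 257)) and prints its message (the exact wording may vary slightly across Python versions):
try:
    bytearray(range(0, 257))  # 257 values: one more than a single byte can hold
except ValueError as e:
    print(e)
# prints something like: byte must be in range(0, 256)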

 

1.2 Building the vocabulary

  • The goal of the BPE tokenization algorithm is to build a vocabulary of commonly occurring subwords like 298: ent (which can be found in entangle, entertain, enter, entrance, entity, …, for example), or even complete words like
318: is
617: some
1212: This
2420: text
  • The BPE algorithm was originally described in 1994: “A New Algorithm for Data Compression” by Philip Gage
  • Before we get to the actual code implementation, the form that is used for LLM tokenizers today can be summarized as follows:

 

1.3 BPE algorithm outline

1. Identify frequent pairs

  • In each iteration, scan the text to find the most commonly occurring pair of bytes (or characters)

2. Replace and record

  • Replace that pair with a new placeholder ID (one not already in use, e.g., if we start with 0…255, the first placeholder would be 256)
  • Record this mapping in a lookup table
  • The size of the lookup table is a hyperparameter, also called “vocabulary size” (for GPT-2, that’s 50,257)

3. Repeat until no gains

  • Keep repeating steps 1 and 2, continually merging the most frequent pairs
  • Stop when no further compression is possible (e.g., no pair occurs more than once)

Decompression (decoding)

  • To restore the original text, reverse the process by substituting each ID with its corresponding pair, using the lookup table

 

1.4 BPE algorithm example

1.4.1 Concrete example of the encoding part (steps 1 & 2)

  • Suppose we have the text (training dataset) the cat in the hat from which we want to build the vocabulary for a BPE tokenizer

Iteration 1

  1. Identify frequent pairs
    • In this text, “th” appears twice (at the beginning and before the second “e”)
  2. Replace and record
    • replace “th” with a new token ID that is not already in use, e.g., 256
    • the new text is: <256>e cat in <256>e hat
    • the new vocabulary is
  0: ...
  ...
  256: "th"

Iteration 2

  1. Identify frequent pairs
    • In the text <256>e cat in <256>e hat, the pair <256>e appears twice
  2. Replace and record
    • replace <256>e with a new token ID that is not already in use, for example, 257.
    • The new text is:
      <257> cat in <257> hat
      
    • The updated vocabulary is:
      0: ...
      ...
      256: "th"
      257: "<256>e"
      

Iteration 3

  1. Identify frequent pairs
    • In the text <257> cat in <257> hat, the pair "<257> " (that is, <257> followed by a whitespace) appears twice (once at the beginning and once before "hat").
  2. Replace and record
    • replace "<257> " with a new token ID that is not already in use, for example, 258.
    • the new text is:
      <258>cat in <258>hat
      
    • The updated vocabulary is:
      0: ...
      ...
      256: "th"
      257: "<256>e"
      258: "<257> "
      
  • and so forth
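  • To make these iterations concrete in code, below is a minimal, self-contained sketch of the training loop (steps 1 and 2) applied to the same text; the variable names (tokens, vocab, merges) are purely illustrative, and ties between equally frequent pairs could in principle be broken differently than in the walkthrough above:
from collections import Counter

text = "the cat in the hat"
tokens = list(text.encode("utf-8"))          # start from raw byte values
vocab = {i: bytes([i]) for i in range(256)}  # token ID -> byte string
merges = {}                                  # new token ID -> the pair it replaces

for _ in range(3):  # three merge iterations, mirroring the example above
    pairs = Counter(zip(tokens, tokens[1:]))
    (p0, p1), count = pairs.most_common(1)[0]
    if count < 2:  # stop when no pair occurs more than once
        break
    new_id = 256 + len(merges)
    vocab[new_id] = vocab[p0] + vocab[p1]
    merges[new_id] = (p0, p1)
    # replace every occurrence of the pair with the new ID
    merged_tokens, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == (p0, p1):
            merged_tokens.append(new_id)
            i += 2
        else:
            merged_tokens.append(tokens[i])
            i += 1
    tokens = merged_tokens
    print(new_id, "->", vocab[new_id])
# should print:
# 256 -> b'th'
# 257 -> b'the'
# 258 -> b'the '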

 

1.4.2 Concrete example of the decoding part (step 3)

  • To restore the original text, we reverse the process by substituting each token ID with its corresponding pair in the reverse order they were introduced
  • Start with the final compressed text: <258>cat in <258>hat
  • Substitute <258> → "<257> ": <257> cat in <257> hat
  • Substitute <257> → "<256>e": <256>e cat in <256>e hat
  • Substitute <256> → "th": the cat in the hat
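  • Continuing the hypothetical tokens and merges variables from the sketch in section 1.4.1, this reverse substitution can be expressed in a few lines:
def bpe_decode(token_ids, merges):
    # undo the merges in the reverse order they were recorded
    for new_id in sorted(merges, reverse=True):  # last-learned merge first
        p0, p1 = merges[new_id]
        expanded = []
        for tid in token_ids:
            expanded.extend([p0, p1] if tid == new_id else [tid])
        token_ids = expanded
    return bytes(token_ids).decode("utf-8")

print(bpe_decode(tokens, merges))
# should print: the cat in the hat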

 

2. A simple BPE implementation

  • Below is an implementation of the algorithm described above, written as a Python class that mimics the tiktoken Python user interface
  • Note that the encoding part above describes the original training step via train(); however, the encode() method works similarly (although it looks a bit more complicated because of the special token handling):
  1. Split the input text into individual bytes
  2. Repeatedly find & replace (merge) adjacent tokens (pairs) when they match any pair in the learned BPE merges (from highest to lowest “rank,” i.e., in the order they were learned)
  3. Continue merging until no more merges can be applied
  4. The final list of token IDs is the encoded output
from collections import Counter, deque
from functools import lru_cache
import json


class BPETokenizerSimple:
    def __init__(self):
        # Maps token_id to token_str (e.g., {11246: "some"})
        self.vocab = {}
        # Maps token_str to token_id (e.g., {"some": 11246})
        self.inverse_vocab = {}
        # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}
        self.bpe_merges = {}

    def train(self, text, vocab_size, allowed_special={"<|endoftext|>"}):
        """
        Train the BPE tokenizer from scratch.

        Args:
            text (str): The training text.
            vocab_size (int): The desired vocabulary size.
            allowed_special (set): A set of special tokens to include.
        """

        # Preprocess: Replace spaces with 'Ġ'
        # Note that Ġ is a particularity of the GPT-2 BPE implementation
        # E.g., "Hello world" might be tokenized as ["Hello", "Ġworld"]
        # (GPT-4 BPE would tokenize it as ["Hello", " world"])
        processed_text = []
        for i, char in enumerate(text):
            if char == " " and i != 0:
                processed_text.append("Ġ")
            if char != " ":
                processed_text.append(char)
        processed_text = "".join(processed_text)

        # Initialize vocab with unique characters, including 'Ġ' if present
        # Start with the 256 single-character tokens corresponding to code points 0-255
        unique_chars = [chr(i) for i in range(256)]

        # Extend unique_chars with characters from processed_text that are not already included
        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)

        # Optionally, ensure 'Ġ' is included if it is relevant to your text processing
        if 'Ġ' not in unique_chars:
            unique_chars.append('Ġ')

        # Now create the vocab and inverse vocab dictionaries
        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}

        # Add allowed special tokens
        if allowed_special:
            for token in allowed_special:
                if token not in self.inverse_vocab:
                    new_id = len(self.vocab)
                    self.vocab[new_id] = token
                    self.inverse_vocab[token] = new_id

        # Tokenize the processed_text into token IDs
        token_ids = [self.inverse_vocab[char] for char in processed_text]

        # BPE steps 1-3: Repeatedly find and replace frequent pairs
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode="most")
            if pair_id is None:  # No more pairs to merge. Stopping training.
                break
            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        # Build the vocabulary with merged tokens
        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def load_vocab_and_merges_from_openai(self, vocab_path, bpe_merges_path):
        """
        Load pre-trained vocabulary and BPE merges from OpenAI's GPT-2 files.

        Args:
            vocab_path (str): Path to the vocab file (GPT-2 calls it 'encoder.json').
            bpe_merges_path (str): Path to the bpe_merges file  (GPT-2 calls it 'vocab.bpe').
        """
        # Load vocabulary
        with open(vocab_path, "r", encoding="utf-8") as file:
            loaded_vocab = json.load(file)
            # loaded_vocab maps token_str to token_id
            self.vocab = {int(v): k for k, v in loaded_vocab.items()}  # token_id: token_str
            self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}  # token_str: token_id

        # Load BPE merges
        with open(bpe_merges_path, "r", encoding="utf-8") as file:
            lines = file.readlines()
            # Skip header line if present
            if lines and lines[0].startswith("#"):
                lines = lines[1:]

            for rank, line in enumerate(lines):
                pair = tuple(line.strip().split())
                if len(pair) != 2:
                    print(f"Line {rank+1} has more than 2 entries: {line.strip()}")
                    continue
                token1, token2 = pair
                if token1 in self.inverse_vocab and token2 in self.inverse_vocab:
                    token_id1 = self.inverse_vocab[token1]
                    token_id2 = self.inverse_vocab[token2]
                    merged_token = token1 + token2
                    if merged_token in self.inverse_vocab:
                        merged_token_id = self.inverse_vocab[merged_token]
                        self.bpe_merges[(token_id1, token_id2)] = merged_token_id
                        # print(f"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})")
                    else:
                        print(f"Merged token '{merged_token}' not found in vocab. Skipping.")
                else:
                    print(f"Skipping pair {pair} as one of the tokens is not in the vocabulary.")

    def encode(self, text):
        """
        Encode the input text into a list of token IDs.

        Args:
            text (str): The text to encode.

        Returns:
            List[int]: The list of token IDs.
        """
        tokens = []
        # Split the text into words based on spaces
        words = text.split(" ")
        for i, word in enumerate(words):
            if i == 0:
                tokens.append(word)
            else:
                tokens.append("Ġ" + word)  # Prepend 'Ġ' to tokens after spaces

        token_ids = []
        for token in tokens:
            if token in self.inverse_vocab:
                # token is contained in the vocabulary as is
                token_id = self.inverse_vocab[token]
                token_ids.append(token_id)
            else:
                # Attempt to handle subword tokenization via BPE
                sub_token_ids = self.tokenize_with_bpe(token)
                token_ids.extend(sub_token_ids)

        return token_ids

    def tokenize_with_bpe(self, token):
        """
        Tokenize a single token using BPE merges.

        Args:
            token (str): The token to tokenize.

        Returns:
            List[int]: The list of token IDs after applying BPE.
        """
        # Tokenize the token into individual characters (as initial token IDs)
        token_ids = [self.inverse_vocab.get(char, None) for char in token]
        if None in token_ids:
            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]
            raise ValueError(f"Characters not found in vocab: {missing_chars}")

        can_merge = True
        while can_merge and len(token_ids) > 1:
            can_merge = False
            new_tokens = []
            i = 0
            while i < len(token_ids) - 1:
                pair = (token_ids[i], token_ids[i + 1])
                if pair in self.bpe_merges:
                    merged_token_id = self.bpe_merges[pair]
                    new_tokens.append(merged_token_id)
                    # Uncomment for educational purposes:
                    # print(f"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')")
                    i += 2  # Skip the next token as it's merged
                    can_merge = True
                else:
                    new_tokens.append(token_ids[i])
                    i += 1
            if i < len(token_ids):
                new_tokens.append(token_ids[i])
            token_ids = new_tokens

        return token_ids

    def decode(self, token_ids):
        """
        Decode a list of token IDs back into a string.

        Args:
            token_ids (List[int]): The list of token IDs to decode.

        Returns:
            str: The decoded string.
        """
        decoded_string = ""
        for token_id in token_ids:
            if token_id not in self.vocab:
                raise ValueError(f"Token ID {token_id} not found in vocab.")
            token = self.vocab[token_id]
            if token.startswith("Ġ"):
                # Replace 'Ġ' with a space
                decoded_string += " " + token[1:]
            else:
                decoded_string += token
        return decoded_string

    def save_vocab_and_merges(self, vocab_path, bpe_merges_path):
        """
        Save the vocabulary and BPE merges to JSON files.

        Args:
            vocab_path (str): Path to save the vocabulary.
            bpe_merges_path (str): Path to save the BPE merges.
        """
        # Save vocabulary
        with open(vocab_path, "w", encoding="utf-8") as file:
            json.dump({k: v for k, v in self.vocab.items()}, file, ensure_ascii=False, indent=2)

        # Save BPE merges as a list of dictionaries
        with open(bpe_merges_path, "w", encoding="utf-8") as file:
            merges_list = [{"pair": list(pair), "new_id": new_id}
                           for pair, new_id in self.bpe_merges.items()]
            json.dump(merges_list, file, ensure_ascii=False, indent=2)

    def load_vocab_and_merges(self, vocab_path, bpe_merges_path):
        """
        Load the vocabulary and BPE merges from JSON files.

        Args:
            vocab_path (str): Path to the vocabulary file.
            bpe_merges_path (str): Path to the BPE merges file.
        """
        # Load vocabulary
        with open(vocab_path, "r", encoding="utf-8") as file:
            loaded_vocab = json.load(file)
            self.vocab = {int(k): v for k, v in loaded_vocab.items()}
            self.inverse_vocab = {v: int(k) for k, v in loaded_vocab.items()}

        # Load BPE merges
        with open(bpe_merges_path, "r", encoding="utf-8") as file:
            merges_list = json.load(file)
            for merge in merges_list:
                pair = tuple(merge['pair'])
                new_id = merge['new_id']
                self.bpe_merges[pair] = new_id

    @lru_cache(maxsize=None)
    def get_special_token_id(self, token):
        return self.inverse_vocab.get(token, None)

    @staticmethod
    def find_freq_pair(token_ids, mode="most"):
        pairs = Counter(zip(token_ids, token_ids[1:]))

        if not pairs:
            # No pairs left to merge (fewer than 2 tokens remain);
            # returning None lets the train() loop stop cleanly
            return None

        if mode == "most":
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == "least":
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("Invalid mode. Choose 'most' or 'least'.")

    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        dq = deque(token_ids)
        replaced = []

        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                # Remove the 2nd token of the pair, 1st was already removed
                dq.popleft()
            else:
                replaced.append(current)

        return replaced
  • There is a lot of code in the BPETokenizerSimple class above, and discussing it in detail is out of scope for this notebook, but the next section offers a short overview of the usage to understand the class methods a bit better

3. BPE implementation walkthrough

  • In practice, I highly recommend using tiktoken as my implementation above focuses on readability and educational purposes, not on performance
  • However, the usage is more or less similar to tiktoken, except that tiktoken does not have a training method
  • Let’s see how my BPETokenizerSimple Python code above works by looking at some examples below (a detailed code discussion is out of scope for this notebook)

3.1 Training, encoding, and decoding

  • First, let’s consider some sample text as our training dataset:
import os
import urllib.request

if not os.path.exists("the-verdict.txt"):
    url = ("https://raw.githubusercontent.com/rasbt/"
           "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
           "the-verdict.txt")
    file_path = "the-verdict.txt"
    urllib.request.urlretrieve(url, file_path)

with open("the-verdict.txt", "r", encoding="utf-8") as f:
    text = f.read()
  • Next, let’s initialize and train the BPE tokenizer with a vocabulary size of 1,000
  • Note that the vocabulary already contains 256 single-character entries by default due to the byte values discussed earlier (plus the “Ġ” whitespace marker and the <|endoftext|> special token, giving 258 initial entries), so we are only “learning” 742 new vocabulary entries
  • For comparison, the GPT-2 vocabulary is 50,257 tokens, the GPT-4 vocabulary is 100,256 tokens (cl100k_base in tiktoken), and GPT-4o uses 199,997 tokens (o200k_base in tiktoken); they have all much bigger training sets compared to our simple example text above
tokenizer = BPETokenizerSimple()
tokenizer.train(text, vocab_size=1000, allowed_special={"<|endoftext|>"})
  • You may want to inspect the vocabulary contents (but note it will create a long list)
# print(tokenizer.vocab)
print(len(tokenizer.vocab))
1000
  • This vocabulary is created by merging 742 times (~ 1000 - len(range(0, 256)))
print(len(tokenizer.bpe_merges))
742
  • This means that the first 256 entries are single-character tokens
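  • You can spot-check this by decoding a few low token IDs, which map back to single characters (for example, IDs 97-99 correspond to the characters 'a', 'b', and 'c'):
for i in (97, 98, 99):
    print(i, "->", tokenizer.decode([i]))
# should print:
# 97 -> a
# 98 -> b
# 99 -> c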

  • Next, let’s use the created merges via the encode method to encode some text:

input_text = "Jack embraced beauty through art and life."
token_ids = tokenizer.encode(input_text)
print(token_ids)
[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46]
print("Number of characters:", len(input_text))
print("Number of token IDs:", len(token_ids))
Number of characters: 42
Number of token IDs: 20
  • From the lengths above, we can see that a 42-character sentence was encoded into 20 token IDs, effectively cutting the input length roughly in half compared to a character-byte-based encoding

  • Note that the vocabulary itself is used in the decode() method, which allows us to map the token IDs back into text:

print(token_ids)
[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46]
print(tokenizer.decode(token_ids))
Jack embraced beauty through art and life.
  • Iterating over each token ID can give us a better understanding of how the token IDs are decoded via the vocabulary:
for token_id in token_ids:
    print(f"{token_id} -> {tokenizer.decode([token_id])}")
424 -> Jack
256 ->
654 -> em
531 -> br
302 -> ac
311 -> ed
256 ->
296 -> be
97 -> a
465 -> ut
121 -> y
595 ->  through
841 ->  ar
116 -> t
287 ->  a
466 -> nd
256 ->
326 -> li
972 -> fe
46 -> .
  • As we can see, most token IDs represent 2-character subwords; that’s because the training data text is very short with not that many repetitive words, and because we used a relatively small vocabulary size

  • As a summary, calling decode(encode()) should be able to reproduce arbitrary input texts:

tokenizer.decode(tokenizer.encode("This is some text."))
'This is some text.'

3.2 Saving and loading the tokenizer

  • Next, let’s look at how we can save the trained tokenizer for reuse later:
# Save trained tokenizer
tokenizer.save_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")
# Load tokenizer
tokenizer2 = BPETokenizerSimple()
tokenizer2.load_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")
  • The loaded tokenizer should be able to produce the same results as before:
print(tokenizer2.decode(token_ids))
Jack embraced beauty through art and life.
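  • As an optional extra sanity check, encoding with the reloaded tokenizer should reproduce the same token IDs as before:
print(tokenizer2.encode(input_text) == token_ids)
# should print: True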

 

3.3 Loading the original GPT-2 BPE tokenizer from OpenAI

  • Finally, let’s load OpenAI’s GPT-2 tokenizer files
import os
import urllib.request

def download_file_if_absent(url, filename):
    if not os.path.exists(filename):
        try:
            with urllib.request.urlopen(url) as response, open(filename, 'wb') as out_file:
                out_file.write(response.read())
            print(f"Downloaded {filename}")
        except Exception as e:
            print(f"Failed to download {filename}. Error: {e}")
    else:
        print(f"{filename} already exists")

files_to_download = {
    "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe": "vocab.bpe",
    "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json": "encoder.json"
}

for url, filename in files_to_download.items():
    download_file_if_absent(url, filename)
vocab.bpe already exists
encoder.json already exists
  • Next, we load the files via the load_vocab_and_merges_from_openai method:
tokenizer_gpt2 = BPETokenizerSimple()
tokenizer_gpt2.load_vocab_and_merges_from_openai(
    vocab_path="encoder.json", bpe_merges_path="vocab.bpe"
)
  • The vocabulary size should be 50257 as we can confirm via the code below:
len(tokenizer_gpt2.vocab)
50257
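  • Likewise, the number of loaded BPE merges should be 50,000, since the GPT-2 vocabulary consists of 256 single-byte tokens, 50,000 merged tokens, and the <|endoftext|> special token:
print(len(tokenizer_gpt2.bpe_merges))
# should print: 50000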
  • We can now use the GPT-2 tokenizer via our BPETokenizerSimple object:
input_text = "This is some text"
token_ids = tokenizer_gpt2.encode(input_text)
print(token_ids)
[1212, 318, 617, 2420]
print(tokenizer_gpt2.decode(token_ids))
This is some text
  • Finally, let's double-check that these token IDs match what the original GPT-2 tokenizer in the tiktoken library produces:
import tiktoken

tik_tokenizer = tiktoken.get_encoding("gpt2")
tik_tokenizer.encode("This is some text")
[1212, 318, 617, 2420]


 

4. Conclusion

  • That’s it! That’s how BPE works in a nutshell, complete with a training method for creating new tokenizers or loading the GPT-2 tokenizer vocabulary and merges from the original OpenAI GPT-2 model
  • I hope you found this brief tutorial useful for educational purposes; if you have any questions, please feel free to open a new Discussion here